Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#381 Feature/sg 000 connect to lab

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/sg-000_connect_to_lab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
  1. # Copyright (c) Megvii, Inc. and its affiliates.
  2. # Apache 2.0 license: https://github.com/Megvii-BaseDetection/YOLOX/blob/main/LICENSE
  3. import itertools
  4. from typing import Optional
  5. import torch
  6. import torch.distributed as dist
  7. from torch.utils.data.sampler import Sampler
  8. class InfiniteSampler(Sampler):
  9. """
  10. In training, we only care about the "infinite stream" of training data.
  11. So this sampler produces an infinite stream of indices and
  12. all workers cooperate to correctly shuffle the indices and sample different indices.
  13. The samplers in each worker effectively produces `indices[worker_id::num_workers]`
  14. where `indices` is an infinite stream of indices consisting of
  15. `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
  16. or `range(size) + range(size) + ...` (if shuffle is False)
  17. """
  18. def __init__(
  19. self,
  20. dataset,
  21. shuffle: bool = True,
  22. seed: Optional[int] = 0,
  23. rank=0,
  24. world_size=1,
  25. ):
  26. """
  27. Args:
  28. size (int): the total number of data of the underlying dataset to sample from
  29. shuffle (bool): whether to shuffle the indices or not
  30. seed (int): the initial seed of the shuffle. Must be the same
  31. across all workers. If None, will use a random seed shared
  32. among workers (require synchronization among all workers).
  33. """
  34. self._size = len(dataset)
  35. assert len(dataset) > 0
  36. self._shuffle = shuffle
  37. self._seed = int(seed)
  38. if dist.is_available() and dist.is_initialized():
  39. self._rank = dist.get_rank()
  40. self._world_size = dist.get_world_size()
  41. else:
  42. self._rank = rank
  43. self._world_size = world_size
  44. def __iter__(self):
  45. start = self._rank
  46. yield from itertools.islice(
  47. self._infinite_indices(), start, None, self._world_size
  48. )
  49. def _infinite_indices(self):
  50. g = torch.Generator()
  51. g.manual_seed(self._seed)
  52. while True:
  53. if self._shuffle:
  54. yield from torch.randperm(self._size, generator=g)
  55. else:
  56. yield from torch.arange(self._size)
  57. def __len__(self):
  58. return self._size // self._world_size
Discard
Tip!

Press p or to see the previous file or, n or to see the next file