#883 Remove InfiniteSampler

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-838-remove-infinite-sampler
@@ -167,7 +167,6 @@ class LRWarmups:
 class Samplers:
     """Static class to hold all the supported Samplers names"""

-    INFINITE = "InfiniteSampler"
     REPEAT_AUG = "RepeatAugSampler"
     DISTRIBUTED = "DistributedSampler"
     SEQUENTIAL = "SequentialSampler"
@@ -1,7 +1,6 @@
-from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
 from super_gradients.training.datasets.samplers.repeated_augmentation_sampler import RepeatAugSampler
 from super_gradients.common.object_names import Samplers
 from super_gradients.common.registry.registry import SAMPLERS


-__all__ = ["SAMPLERS", "Samplers", "InfiniteSampler", "RepeatAugSampler"]
+__all__ = ["SAMPLERS", "Samplers", "RepeatAugSampler"]
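With both the name constant and the package export removed, code that relied on an endless index stream needs another route. One plain-PyTorch workaround (a sketch only, not part of this PR or of the SuperGradients API) is to cycle an ordinary shuffling DataLoader:

from itertools import islice

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset


def cycle_batches(loader: DataLoader):
    # Restart the loader whenever it is exhausted; RandomSampler reshuffles each pass.
    while True:
        yield from loader


dataset = TensorDataset(torch.arange(10).float())
loader = DataLoader(dataset, batch_size=4, sampler=RandomSampler(dataset))

# Draw exactly 7 batches from the endless stream, independent of the dataset length.
for batch in islice(cycle_batches(loader), 7):
    print(batch)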
Deleted module: super_gradients.training.datasets.samplers.infinite_sampler

# Copyright (c) Megvii, Inc. and its affiliates.
# Apache 2.0 license: https://github.com/Megvii-BaseDetection/YOLOX/blob/main/LICENSE
import itertools
from typing import Optional

import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler
from deprecate import deprecated

from super_gradients.common.object_names import Samplers
from super_gradients.common.registry.registry import register_sampler


@register_sampler(Samplers.INFINITE)
class InfiniteSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.

    The samplers in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)
    """

    @deprecated(target=None, deprecated_in="3.0.8", remove_in="3.1.0")
    def __init__(
        self,
        dataset,
        shuffle: bool = True,
        seed: Optional[int] = 0,
        rank=0,
        world_size=1,
    ):
        """
        :param size:    Total number of data of the underlying dataset to sample from
        :param shuffle: Whether to shuffle the indices or not
        :param seed:    Initial seed of the shuffle. Must be the same across all workers.
                        If None, will use a random seed shared among workers (require synchronization among all workers).
        """
        self._size = len(dataset)
        assert len(dataset) > 0
        self._shuffle = shuffle
        self._seed = int(seed)

        if dist.is_available() and dist.is_initialized():
            self._rank = dist.get_rank()
            self._world_size = dist.get_world_size()
        else:
            self._rank = rank
            self._world_size = world_size

    def __iter__(self):
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g)
            else:
                yield from torch.arange(self._size)

    def __len__(self):
        return self._size // self._world_size
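For reference, the sharding described in the docstring above — each rank consuming `indices[rank::world_size]` of one shared, seeded, endlessly reshuffled stream — boils down to the following toy sketch (plain PyTorch, illustrative sizes, not SuperGradients code):

import itertools

import torch


def infinite_indices(size: int, seed: int = 0):
    # Every rank regenerates the same seeded permutation stream, forever.
    g = torch.Generator()
    g.manual_seed(seed)
    while True:
        yield from torch.randperm(size, generator=g).tolist()


size, world_size = 8, 2
for rank in range(world_size):
    # Each rank keeps only its interleaved slice of the shared stream.
    shard = itertools.islice(infinite_indices(size), rank, None, world_size)
    print(f"rank {rank}:", list(itertools.islice(shard, 6)))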
@@ -32,7 +32,7 @@ from super_gradients.common.factories.losses_factory import LossesFactory
 from super_gradients.common.factories.metrics_factory import MetricsFactory

 from super_gradients.training import utils as core_utils, models, dataloaders
-from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler
+from super_gradients.training.datasets.samplers import RepeatAugSampler
 from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat
 from super_gradients.training.metrics.metric_utils import (
     get_metrics_titles,
@@ -160,7 +160,6 @@ class Trainer:
         self.strict_load = StrictLoad.ON
         self.load_ema_as_net = False
         self.ckpt_best_name = "ckpt_best.pth"
-        self._infinite_train_loader = False
         self._first_backward = True

         # METRICS
@@ -461,11 +460,7 @@ class Trainer:
             progress_bar_train_loader.set_postfix(**pbar_message_dict)
             self.phase_callback_handler.on_train_batch_end(context)

-            # TODO: ITERATE BY MAX ITERS
-            # FOR INFINITE SAMPLERS WE MUST BREAK WHEN REACHING LEN ITERATIONS.
-            if (self._infinite_train_loader and batch_idx == len(self.train_loader) - 1) or (
-                self.max_train_batches is not None and self.max_train_batches - 1 <= batch_idx
-            ):
+            if self.max_train_batches is not None and self.max_train_batches - 1 <= batch_idx:
                 break

         self.train_monitored_values = sg_trainer_utils.update_monitored_values_dict(
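With the infinite-loader guard gone, the only early exit left in the batch loop is max_train_batches. A minimal sketch of the resulting control flow (hypothetical run_epoch helper, not the real Trainer method):

from typing import Iterable, Optional


def run_epoch(train_loader: Iterable, max_train_batches: Optional[int] = None) -> None:
    for batch_idx, batch in enumerate(train_loader):
        ...  # forward / backward / metric updates happen here in the real Trainer
        # Finite samplers end the epoch on their own; this is the only early exit.
        if max_train_batches is not None and max_train_batches - 1 <= batch_idx:
            break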
@@ -1022,10 +1017,10 @@ class Trainer:
                     "You are using a SequentialSampler on you training dataloader, while working on DDP. "
                     "You are using a SequentialSampler on you training dataloader, while working on DDP. "
                     "This cancels the DDP benefits since it makes each process iterate through the entire dataset"
                     "This cancels the DDP benefits since it makes each process iterate through the entire dataset"
                 )
                 )
-            if not isinstance(train_sampler, (DistributedSampler, InfiniteSampler, RepeatAugSampler)):
+            if not isinstance(train_sampler, (DistributedSampler, RepeatAugSampler)):
                 logger.warning(
                 logger.warning(
                     "The training sampler you are using might not support DDP. "
                     "The training sampler you are using might not support DDP. "
-                    "If it doesnt, please use one of the following sampler: DistributedSampler, InfiniteSampler, RepeatAugSampler"
+                    "If it doesnt, please use one of the following sampler: DistributedSampler, RepeatAugSampler"
                 )
                 )
         self.training_params = TrainingParams()
         self.training_params = TrainingParams()
         self.training_params.override(**training_params)
         self.training_params.override(**training_params)
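Outside of the SuperGradients Trainer, the same DDP compatibility check can be expressed as a standalone helper (a sketch with a hypothetical name; resolving the sampler via loader.batch_sampler is an assumption, not the Trainer's exact lookup):

import logging

from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler

logger = logging.getLogger(__name__)


def warn_if_sampler_unsuited_for_ddp(loader: DataLoader) -> None:
    # Resolve the underlying sampler whether or not the loader wraps it in a batch_sampler.
    sampler = loader.batch_sampler.sampler if loader.batch_sampler is not None else loader.sampler
    if isinstance(sampler, SequentialSampler):
        logger.warning("SequentialSampler under DDP makes every process iterate the entire dataset.")
    elif not isinstance(sampler, DistributedSampler):
        # After this PR, the Trainer accepts DistributedSampler or RepeatAugSampler here.
        logger.warning("This sampler might not support DDP; prefer DistributedSampler or RepeatAugSampler.")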
@@ -1164,10 +1159,6 @@ class Trainer:

         self._initialize_mixed_precision(self.training_params.mixed_precision)

-        self._infinite_train_loader = (hasattr(self.train_loader, "sampler") and isinstance(self.train_loader.sampler, InfiniteSampler)) or (
-            hasattr(self.train_loader, "batch_sampler") and isinstance(self.train_loader.batch_sampler.sampler, InfiniteSampler)
-        )
-
         self.ckpt_best_name = self.training_params.ckpt_best_name

         if self.training_params.max_train_batches is not None:
@@ -165,10 +165,9 @@ class DataLoaderFactoryTest(unittest.TestCase):

     def test_imagenet_resnet50_kd_train_creation(self):
         # Here we need to overwrite the sampler because the RepeatAugSampler used in KD is only supported for DDP
-        dl = imagenet_resnet50_kd_train(dataloader_params={"sampler": {"InfiniteSampler": {}}})
+        dl = imagenet_resnet50_kd_train(dataloader_params={"sampler": {"RandomSampler": {}}})
         self.assertTrue(isinstance(dl, DataLoader))
         self.assertTrue(isinstance(dl.dataset, ImageNetDataset))
-        self.assertTrue(dl.sampler._shuffle)

     def test_imagenet_resnet50_kd_val_creation(self):
         dl = imagenet_resnet50_kd_val()