@@ -32,7 +32,7 @@ from super_gradients.common.factories.losses_factory import LossesFactory
 from super_gradients.common.factories.metrics_factory import MetricsFactory
 
 from super_gradients.training import utils as core_utils, models, dataloaders
-from super_gradients.training.datasets.samplers import InfiniteSampler, RepeatAugSampler
+from super_gradients.training.datasets.samplers import RepeatAugSampler
 from super_gradients.training.exceptions.sg_trainer_exceptions import UnsupportedOptimizerFormat
 from super_gradients.training.metrics.metric_utils import (
     get_metrics_titles,
@@ -160,7 +160,6 @@ class Trainer:
         self.strict_load = StrictLoad.ON
         self.load_ema_as_net = False
         self.ckpt_best_name = "ckpt_best.pth"
-        self._infinite_train_loader = False
         self._first_backward = True
 
         # METRICS
@@ -461,11 +460,7 @@ class Trainer:
             progress_bar_train_loader.set_postfix(**pbar_message_dict)
             self.phase_callback_handler.on_train_batch_end(context)
 
-            # TODO: ITERATE BY MAX ITERS
-            # FOR INFINITE SAMPLERS WE MUST BREAK WHEN REACHING LEN ITERATIONS.
-            if (self._infinite_train_loader and batch_idx == len(self.train_loader) - 1) or (
-                self.max_train_batches is not None and self.max_train_batches - 1 <= batch_idx
-            ):
+            if self.max_train_batches is not None and self.max_train_batches - 1 <= batch_idx:
                 break
 
         self.train_monitored_values = sg_trainer_utils.update_monitored_values_dict(
@@ -1022,10 +1017,10 @@ class Trainer:
                     "You are using a SequentialSampler on you training dataloader, while working on DDP. "
                     "This cancels the DDP benefits since it makes each process iterate through the entire dataset"
                 )
-            if not isinstance(train_sampler, (DistributedSampler, InfiniteSampler, RepeatAugSampler)):
+            if not isinstance(train_sampler, (DistributedSampler, RepeatAugSampler)):
                 logger.warning(
                     "The training sampler you are using might not support DDP. "
-                    "If it doesnt, please use one of the following sampler: DistributedSampler, InfiniteSampler, RepeatAugSampler"
+                    "If it doesnt, please use one of the following sampler: DistributedSampler, RepeatAugSampler"
                 )
         self.training_params = TrainingParams()
         self.training_params.override(**training_params)
@@ -1164,10 +1159,6 @@ class Trainer:
 
         self._initialize_mixed_precision(self.training_params.mixed_precision)
 
-        self._infinite_train_loader = (hasattr(self.train_loader, "sampler") and isinstance(self.train_loader.sampler, InfiniteSampler)) or (
-            hasattr(self.train_loader, "batch_sampler") and isinstance(self.train_loader.batch_sampler.sampler, InfiniteSampler)
-        )
-
         self.ckpt_best_name = self.training_params.ckpt_best_name
 
         if self.training_params.max_train_batches is not None: