#396 Trainer constructor cleanup

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-266_clean_trainer_ctor
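This PR removes the model_checkpoints_location and post_prediction_callback arguments from the Trainer constructor: checkpoints appear to default to the local /checkpoints directory (see the StrictLoadEnumTest hunk below), and post-prediction callbacks are now created where they are used, either called directly on model output or passed to DetectionMetrics. A minimal sketch of the resulting pattern, pieced together from the test changes below (the experiment name and num_classes here are illustrative):

from super_gradients import Trainer
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import detection_test_dataloader
from super_gradients.training.metrics.detection_metrics import DetectionMetrics
from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback

# Trainer now takes the experiment name (plus optional flags such as multi_gpu); no checkpoint
# location or post-prediction callback is passed to the constructor.
trainer = Trainer("trainer_ctor_cleanup_example")
model = models.get("yolox_s", num_classes=5)

# The post-prediction callback is instantiated by the caller and handed to the metric directly,
# instead of being read back from trainer.post_prediction_callback.
post_prediction_callback = YoloPostPredictionCallback()
test_metrics = [DetectionMetrics(post_prediction_callback=post_prediction_callback, num_cls=5)]

trainer.test(model=model, silent_mode=True,
             test_metrics_list=test_metrics,
             test_loader=detection_test_dataloader(image_size=320))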
@@ -18,9 +18,7 @@ class TestDatasetStatisticsTensorboardLogger(unittest.TestCase):
         """
         # Create dataset

-        trainer = Trainer('dataset_statistics_visual_test',
-                          model_checkpoints_location='local',
-                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer = Trainer('dataset_statistics_visual_test')

         model = models.get("yolox_s")

@@ -12,10 +12,9 @@ class TestDetectionUtils(unittest.TestCase):
     def test_visualization(self):

         # Create Yolo model
-        trainer = Trainer('visualization_test',
-                          model_checkpoints_location='local',
-                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer = Trainer('visualization_test')
         model = models.get("yolox_n", pretrained_weights="coco")
+        post_prediction_callback = YoloPostPredictionCallback()

         # Simulate one iteration of validation subset
         valid_loader = coco2017_val()
@@ -23,7 +22,7 @@ class TestDetectionUtils(unittest.TestCase):
         imgs = core_utils.tensor_container_to_device(imgs, trainer.device)
         targets = core_utils.tensor_container_to_device(targets, trainer.device)
         output = model(imgs)
-        output = trainer.post_prediction_callback(output)
+        output = post_prediction_callback(output)
         # Visualize the batch
         DetectionVisualization.visualize_batch(imgs, output, targets, batch_i,
                                                COCO_DETECTION_CLASSES_LIST, trainer.checkpoints_dir_path)
@@ -58,7 +58,7 @@ class EarlyStopTest(unittest.TestCase):
         Test for mode=min metric, test that training stops after no improvement in metric value for amount of `patience`
         epochs.
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")

         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", patience=3, verbose=True)
         phase_callbacks = [early_stop_loss]
@@ -80,7 +80,7 @@ class EarlyStopTest(unittest.TestCase):
         Test for mode=max metric, test that training stops after no improvement in metric value for amount of `patience`
         epochs.
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")
         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=3,
                                    verbose=True)
         phase_callbacks = [early_stop_acc]
@@ -101,7 +101,7 @@ class EarlyStopTest(unittest.TestCase):
         """
         Test for mode=min metric, test that training stops after metric value reaches the `threshold` value.
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")

         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", threshold=0.1, verbose=True)
         phase_callbacks = [early_stop_loss]
@@ -121,7 +121,7 @@ class EarlyStopTest(unittest.TestCase):
         """
         Test for mode=max metric, test that training stops after metric value reaches the `threshold` value.
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")

         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", threshold=0.94,
                                    verbose=True)
@@ -144,7 +144,7 @@ class EarlyStopTest(unittest.TestCase):
         Test that training stops when monitor value is not a finite number. Test case of NaN and Inf values.
         """
         # test Nan value
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")

         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", check_finite=True,
                                     verbose=True)
@@ -162,7 +162,7 @@ class EarlyStopTest(unittest.TestCase):
         self.assertEqual(excepted_end_epoch, fake_loss.count // 2)

         # test Inf value
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")

         early_stop_loss = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="LossTest", mode="min", patience=3, verbose=True)
         phase_callbacks = [early_stop_loss]
@@ -183,7 +183,7 @@ class EarlyStopTest(unittest.TestCase):
         Test for `min_delta` argument, metric value is considered an improvement only if
         current_value - min_delta > best_value
         """
-        trainer = Trainer("early_stop_test", model_checkpoints_location='local')
+        trainer = Trainer("early_stop_test")

         early_stop_acc = EarlyStop(Phase.VALIDATION_EPOCH_END, monitor="MetricTest", mode="max", patience=2,
                                    min_delta=0.1, verbose=True)
@@ -11,7 +11,7 @@ from super_gradients.training.metrics import Accuracy, Top5
 class FactoriesTest(unittest.TestCase):

     def test_training_with_factories(self):
-        trainer = Trainer("test_train_with_factories", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_factories")
         net = models.get("resnet18", num_classes=5)
         train_params = {"max_epochs": 2,
                         "lr_updates": [1],
@@ -6,7 +6,6 @@ from super_gradients import Trainer
 import torch
 from torch.utils.data import TensorDataset, DataLoader
 from super_gradients.training.metrics import Accuracy
-from super_gradients.training.exceptions.sg_trainer_exceptions import IllegalDataloaderInitialization


 class InitializeWithDataloadersTest(unittest.TestCase):
@@ -26,27 +25,8 @@ class InitializeWithDataloadersTest(unittest.TestCase):
         label = torch.randint(0, len(self.testcase_classes), size=(test_size,))
         self.testcase_testloader = DataLoader(TensorDataset(inp, label))

-    def test_initialization_rules(self):
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          train_loader=self.testcase_trainloader)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          valid_loader=self.testcase_validloader)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          classes=self.testcase_classes)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          train_loader=self.testcase_trainloader, valid_loader=self.testcase_validloader)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          train_loader=self.testcase_trainloader, classes=self.testcase_classes)
-        self.assertRaises(IllegalDataloaderInitialization, Trainer, "test_name", model_checkpoints_location='local',
-                          valid_loader=self.testcase_validloader, classes=self.testcase_classes)
-        Trainer("test_name", model_checkpoints_location='local', train_loader=self.testcase_trainloader,
-                valid_loader=self.testcase_validloader, classes=self.testcase_classes)
-        Trainer("test_name", model_checkpoints_location='local', train_loader=self.testcase_trainloader,
-                valid_loader=self.testcase_validloader, test_loader=self.testcase_testloader,
-                classes=self.testcase_classes)
-
     def test_train_with_dataloaders(self):
-        trainer = Trainer(experiment_name="test_name", model_checkpoints_location="local")
+        trainer = Trainer(experiment_name="test_name")
         model = models.get("resnet18", num_classes=5)
         trainer.train(model=model,
                       training_params={"max_epochs": 2,
@@ -29,7 +29,7 @@ class LoadCheckpointWithEmaTest(unittest.TestCase):
     def test_ema_ckpt_reload(self):
         # Define Model
         net = LeNet()
-        trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
+        trainer = Trainer("ema_ckpt_test")
         trainer.train(model=net, training_params=self.train_params,
                       train_loader=classification_test_dataloader(),
                       valid_loader=classification_test_dataloader())
@@ -38,7 +38,7 @@ class LoadCheckpointWithEmaTest(unittest.TestCase):

         # TRAIN FOR 1 MORE EPOCH AND COMPARE THE NET AT THE BEGINNING OF EPOCH 3 AND THE END OF EPOCH NUMBER 2
         net = LeNet()
-        trainer = Trainer("ema_ckpt_test", model_checkpoints_location='local')
+        trainer = Trainer("ema_ckpt_test")

         net_collector = PreTrainingEMANetCollector()
         self.train_params["resume"] = True
@@ -10,7 +10,7 @@ class LRCooldownTest(unittest.TestCase):
     def test_lr_cooldown_with_lr_scheduling(self):
         # Define Model
         net = LeNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")

         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -38,7 +38,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_lr_warmup(self):
         # Define Model
         net = LeNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")

         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -60,7 +60,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_lr_warmup_with_lr_scheduling(self):
         # Define model
         net = LeNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")

         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -85,7 +85,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_warmup_initial_lr(self):
         # Define model
         net = LeNet()
-        trainer = Trainer("test_warmup_initial_lr", model_checkpoints_location='local')
+        trainer = Trainer("test_warmup_initial_lr")

         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -107,7 +107,7 @@ class LRWarmupTest(unittest.TestCase):
     def test_custom_lr_warmup(self):
         # Define model
         net = LeNet()
-        trainer = Trainer("custom_lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("custom_lr_warmup_test")

         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]
@@ -12,7 +12,7 @@ from torchmetrics import MetricCollection

 class PhaseContextTest(unittest.TestCase):
     def context_information_in_train_test(self):
-        trainer = Trainer("context_information_in_train_test", model_checkpoints_location='local')
+        trainer = Trainer("context_information_in_train_test")

         net = ResNet18(num_classes=5, arch_params={})

@@ -31,7 +31,7 @@ class ContextMethodsCheckerCallback(PhaseCallback):
 class ContextMethodsTest(unittest.TestCase):
     def test_access_to_methods_by_phase(self):
         net = LeNet()
-        trainer = Trainer("test_access_to_methods_by_phase", model_checkpoints_location='local')
+        trainer = Trainer("test_access_to_methods_by_phase")

         phase_callbacks = []
         for phase in Phase:
@@ -14,21 +14,21 @@ class PretrainedModelsUnitTest(unittest.TestCase):
         self.imagenet_pretrained_models = ["resnet50", "repvgg_a0", "regnetY800"]

     def test_pretrained_resnet50_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_resnet50_unit_test', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_resnet50_unit_test',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("resnet50", pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()],
                      metrics_progress_verbose=True)

     def test_pretrained_regnetY800_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_regnetY800_unit_test', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_regnetY800_unit_test',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("regnetY800", pretrained_weights="imagenet")
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()],
                      metrics_progress_verbose=True)

     def test_pretrained_repvgg_a0_imagenet(self):
-        trainer = Trainer('imagenet_pretrained_repvgg_a0_unit_test', model_checkpoints_location='local',
+        trainer = Trainer('imagenet_pretrained_repvgg_a0_unit_test',
                           multi_gpu=MultiGPUMode.OFF)
         model = models.get("repvgg_a0", pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
         trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()],
@@ -18,7 +18,7 @@ class SaveCkptListUnitTest(unittest.TestCase):
                         "greater_metric_to_watch_is_better": True}

         # Define Model
-        trainer = Trainer("save_ckpt_test", model_checkpoints_location='local')
+        trainer = Trainer("save_ckpt_test")

         # Build Model
         model = models.get("resnet18_cifar", arch_params={"num_classes": 10})
@@ -53,7 +53,7 @@ class StrictLoadEnumTest(unittest.TestCase):
         torch.save(cls.change_state_dict_keys(cls.original_torch_model.state_dict()), cls.checkpoint_diff_keys_path)

         # Save the model's state_dict checkpoint in Trainer format
-        cls.trainer = Trainer("load_checkpoint_test", model_checkpoints_location='local')  # Saves in /checkpoints
+        cls.trainer = Trainer("load_checkpoint_test")  # Saves in /checkpoints
         cls.trainer.set_net(cls.original_torch_model)
         # FIXME: after uniting init and build_model we should remove this
         cls.trainer.sg_logger = BaseSGLogger('project_name', 'load_checkpoint_test', 'local', resumed=False,
@@ -6,9 +6,9 @@ from super_gradients.training.dataloaders.dataloaders import classification_test
     detection_test_dataloader, segmentation_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training import MultiGPUMode, models
-from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
+from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback


 class TestWithoutTrainTest(unittest.TestCase):
@@ -26,22 +26,21 @@ class TestWithoutTrainTest(unittest.TestCase):

     @staticmethod
     def get_classification_trainer(name=''):
-        trainer = Trainer(name, model_checkpoints_location='local')
+        trainer = Trainer(name)
         model = models.get("resnet18", num_classes=5)
         return trainer, model

     @staticmethod
     def get_detection_trainer(name=''):
-        trainer = Trainer(name, model_checkpoints_location='local',
-                          multi_gpu=MultiGPUMode.OFF,
-                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer = Trainer(name,
+                          multi_gpu=MultiGPUMode.OFF)
         model = models.get("yolox_s", num_classes=5)
         return trainer, model

     @staticmethod
     def get_segmentation_trainer(name=''):
         shelfnet_lw_arch_params = {"num_classes": 5, "load_checkpoint": False}
-        trainer = Trainer(name, model_checkpoints_location='local', multi_gpu=False)
+        trainer = Trainer(name)
         model = models.get('shelfnet34_lw', arch_params=shelfnet_lw_arch_params)
         return trainer, model

@@ -52,7 +51,7 @@ class TestWithoutTrainTest(unittest.TestCase):

         trainer, model = self.get_detection_trainer(self.folder_names[1])

-        test_metrics = [DetectionMetrics(post_prediction_callback=trainer.post_prediction_callback, num_cls=5)]
+        test_metrics = [DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=5)]

         assert isinstance(trainer.test(model=model, silent_mode=True,
                                        test_metrics_list=test_metrics, test_loader=detection_test_dataloader(image_size=320)), tuple)
@@ -11,7 +11,7 @@ import shutil

 class SgTrainerLoggingTest(unittest.TestCase):
     def test_train_logging(self):
-        trainer = Trainer("test_train_with_full_log", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_full_log")

         net = ResNet18(num_classes=5, arch_params={})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
@@ -19,7 +19,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
     """

     def test_train_with_external_criterion(self):
-        trainer = Trainer("external_criterion_test", model_checkpoints_location='local')
+        trainer = Trainer("external_criterion_test")
         dataloader = classification_test_dataloader(batch_size=10)

         model = models.get("resnet18", arch_params={"num_classes": 5})
@@ -33,7 +33,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)

     def test_train_with_external_optimizer(self):
-        trainer = Trainer("external_optimizer_test", model_checkpoints_location='local')
+        trainer = Trainer("external_optimizer_test")
         dataloader = classification_test_dataloader(batch_size=10)

         model = models.get("resnet18", arch_params={"num_classes": 5})
@@ -47,7 +47,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)

     def test_train_with_external_scheduler(self):
-        trainer = Trainer("external_scheduler_test", model_checkpoints_location='local')
+        trainer = Trainer("external_scheduler_test")
         dataloader = classification_test_dataloader(batch_size=10)

         lr = 0.3
@@ -66,7 +66,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         self.assertTrue(lr_scheduler.get_last_lr()[0] == lr * 0.1 * 0.1)

     def test_train_with_external_scheduler_class(self):
-        trainer = Trainer("external_scheduler_test", model_checkpoints_location='local')
+        trainer = Trainer("external_scheduler_test")
         dataloader = classification_test_dataloader(batch_size=10)

         model = models.get("resnet18", arch_params={"num_classes": 5})
@@ -81,7 +81,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)

     def test_train_with_reduce_on_plateau(self):
-        trainer = Trainer("external_reduce_on_plateau_scheduler_test", model_checkpoints_location='local')
+        trainer = Trainer("external_reduce_on_plateau_scheduler_test")
         dataloader = classification_test_dataloader(batch_size=10)

         lr = 0.3
@@ -101,7 +101,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         self.assertTrue(lr_scheduler._last_lr[0] == lr * 0.1)

     def test_train_with_external_metric(self):
-        trainer = Trainer("external_metric_test", model_checkpoints_location='local')
+        trainer = Trainer("external_metric_test")
         dataloader = classification_test_dataloader(batch_size=10)

         model = models.get("resnet18", arch_params={"num_classes": 5})
@@ -114,7 +114,7 @@ class TrainWithInitializedObjectsTest(unittest.TestCase):
         trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)

     def test_train_with_external_dataloaders(self):
-        trainer = Trainer("external_data_loader_test", model_checkpoints_location='local')
+        trainer = Trainer("external_data_loader_test")

         batch_size = 5
         trainset = torch.utils.data.TensorDataset(torch.Tensor(np.random.random((10, 3, 32, 32))),
@@ -12,7 +12,7 @@ class TrainWithPreciseBNTest(unittest.TestCase):
     """

     def test_train_with_precise_bn_explicit_size(self):
-        trainer = Trainer("test_train_with_precise_bn_explicit_size", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_precise_bn_explicit_size")
         net = ResNet18(num_classes=5, arch_params={})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
                         "lr_warmup_epochs": 0, "initial_lr": 0.1, "loss": "cross_entropy", "optimizer": "SGD",
@@ -26,7 +26,7 @@ class TrainWithPreciseBNTest(unittest.TestCase):
                       valid_loader=classification_test_dataloader(batch_size=10))

     def test_train_with_precise_bn_implicit_size(self):
-        trainer = Trainer("test_train_with_precise_bn_implicit_size", model_checkpoints_location='local')
+        trainer = Trainer("test_train_with_precise_bn_implicit_size")

         net = ResNet18(num_classes=5, arch_params={})
         train_params = {"max_epochs": 2, "lr_updates": [1], "lr_decay_factor": 0.1, "lr_mode": "step",
@@ -29,7 +29,7 @@ class UpdateParamGroupsTest(unittest.TestCase):
     def test_lr_scheduling_with_update_param_groups(self):
         # Define Model
         net = TestNet()
-        trainer = Trainer("lr_warmup_test", model_checkpoints_location='local')
+        trainer = Trainer("lr_warmup_test")

         lrs = []
         phase_callbacks = [TestLRCallback(lr_placeholder=lrs)]