Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#396 Trainer constructor cleanup

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-266_clean_trainer_ctor
@@ -6,9 +6,9 @@ from super_gradients.training.dataloaders.dataloaders import classification_test
     detection_test_dataloader, segmentation_test_dataloader
     detection_test_dataloader, segmentation_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training import MultiGPUMode, models
 from super_gradients.training import MultiGPUMode, models
-from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.detection_metrics import DetectionMetrics
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
 from super_gradients.training.metrics.segmentation_metrics import PixelAccuracy, IoU
+from super_gradients.training.models.detection_models.yolo_base import YoloPostPredictionCallback
 
 
 
 
 class TestWithoutTrainTest(unittest.TestCase):
 class TestWithoutTrainTest(unittest.TestCase):
@@ -26,22 +26,21 @@ class TestWithoutTrainTest(unittest.TestCase):
 
 
     @staticmethod
     @staticmethod
     def get_classification_trainer(name=''):
     def get_classification_trainer(name=''):
-        trainer = Trainer(name, model_checkpoints_location='local')
+        trainer = Trainer(name)
         model = models.get("resnet18", num_classes=5)
         model = models.get("resnet18", num_classes=5)
         return trainer, model
         return trainer, model
 
 
     @staticmethod
     @staticmethod
     def get_detection_trainer(name=''):
     def get_detection_trainer(name=''):
-        trainer = Trainer(name, model_checkpoints_location='local',
-                          multi_gpu=MultiGPUMode.OFF,
-                          post_prediction_callback=YoloPostPredictionCallback())
+        trainer = Trainer(name,
+                          multi_gpu=MultiGPUMode.OFF)
         model = models.get("yolox_s", num_classes=5)
         model = models.get("yolox_s", num_classes=5)
         return trainer, model
         return trainer, model
 
 
     @staticmethod
     @staticmethod
     def get_segmentation_trainer(name=''):
     def get_segmentation_trainer(name=''):
         shelfnet_lw_arch_params = {"num_classes": 5, "load_checkpoint": False}
         shelfnet_lw_arch_params = {"num_classes": 5, "load_checkpoint": False}
-        trainer = Trainer(name, model_checkpoints_location='local', multi_gpu=False)
+        trainer = Trainer(name)
         model = models.get('shelfnet34_lw', arch_params=shelfnet_lw_arch_params)
         model = models.get('shelfnet34_lw', arch_params=shelfnet_lw_arch_params)
         return trainer, model
         return trainer, model
 
 
@@ -52,7 +51,7 @@ class TestWithoutTrainTest(unittest.TestCase):
 
 
         trainer, model = self.get_detection_trainer(self.folder_names[1])
         trainer, model = self.get_detection_trainer(self.folder_names[1])
 
 
-        test_metrics = [DetectionMetrics(post_prediction_callback=trainer.post_prediction_callback, num_cls=5)]
+        test_metrics = [DetectionMetrics(post_prediction_callback=YoloPostPredictionCallback(), num_cls=5)]
 
 
         assert isinstance(trainer.test(model=model, silent_mode=True,
         assert isinstance(trainer.test(model=model, silent_mode=True,
                                        test_metrics_list=test_metrics, test_loader=detection_test_dataloader(image_size=320)), tuple)
                                        test_metrics_list=test_metrics, test_loader=detection_test_dataloader(image_size=320)), tuple)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file