Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#867 Fix trainer test

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:hotfix/SG-000-minor_trainer_test_fix
1 changed file with 11 additions and 9 deletions
  1. 11
    9
      tests/end_to_end_tests/trainer_test.py
@@ -10,6 +10,7 @@ import os
 from super_gradients import Trainer
 from super_gradients import Trainer
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
 from super_gradients.training.metrics import Accuracy, Top5
 from super_gradients.training.metrics import Accuracy, Top5
+from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path
 
 
 
 
 class TestTrainer(unittest.TestCase):
 class TestTrainer(unittest.TestCase):
@@ -17,7 +18,7 @@ class TestTrainer(unittest.TestCase):
     def setUp(cls):
     def setUp(cls):
         super_gradients.init_trainer()
         super_gradients.init_trainer()
         # NAMES FOR THE EXPERIMENTS TO LATER DELETE
         # NAMES FOR THE EXPERIMENTS TO LATER DELETE
-        cls.folder_names = ["test_train", "test_save_load", "test_load_w", "test_load_w2", "test_load_w3", "test_checkpoint_content", "analyze"]
+        cls.experiment_names = ["test_train", "test_save_load", "test_load_w", "test_load_w2", "test_load_w3", "test_checkpoint_content", "analyze"]
         cls.training_params = {
         cls.training_params = {
             "max_epochs": 1,
             "max_epochs": 1,
             "silent_mode": True,
             "silent_mode": True,
@@ -34,10 +35,11 @@ class TestTrainer(unittest.TestCase):
 
 
     @classmethod
     @classmethod
     def tearDownClass(cls) -> None:
     def tearDownClass(cls) -> None:
-        # ERASE ALL THE FOLDERS THAT WERE CREATED DURING THIS TEST
-        for folder in cls.folder_names:
-            if os.path.isdir(os.path.join("checkpoints", folder)):
-                shutil.rmtree(os.path.join("checkpoints", folder))
+        # ERASE ALL THE EXPERIMENT FOLDERS THAT WERE CREATED DURING THIS TEST
+        for experiment_name in cls.experiment_names:
+            experiment_dir = get_checkpoints_dir_path(experiment_name=experiment_name)
+            if os.path.isdir(experiment_dir):
+                shutil.rmtree(experiment_dir)
 
 
     @staticmethod
     @staticmethod
     def get_classification_trainer(name=""):
     def get_classification_trainer(name=""):
@@ -46,27 +48,27 @@ class TestTrainer(unittest.TestCase):
         return trainer, model
         return trainer, model
 
 
     def test_train(self):
     def test_train(self):
-        trainer, model = self.get_classification_trainer(self.folder_names[0])
+        trainer, model = self.get_classification_trainer(self.experiment_names[0])
         trainer.train(
         trainer.train(
             model=model, training_params=self.training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
             model=model, training_params=self.training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
         )
         )
 
 
     def test_save_load(self):
     def test_save_load(self):
-        trainer, model = self.get_classification_trainer(self.folder_names[1])
+        trainer, model = self.get_classification_trainer(self.experiment_names[1])
         trainer.train(
         trainer.train(
             model=model, training_params=self.training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
             model=model, training_params=self.training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
         )
         )
         resume_training_params = self.training_params.copy()
         resume_training_params = self.training_params.copy()
         resume_training_params["resume"] = True
         resume_training_params["resume"] = True
         resume_training_params["max_epochs"] = 2
         resume_training_params["max_epochs"] = 2
-        trainer, model = self.get_classification_trainer(self.folder_names[1])
+        trainer, model = self.get_classification_trainer(self.experiment_names[1])
         trainer.train(
         trainer.train(
             model=model, training_params=resume_training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
             model=model, training_params=resume_training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
         )
         )
 
 
     def test_checkpoint_content(self):
     def test_checkpoint_content(self):
         """VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
         """VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
-        trainer, model = self.get_classification_trainer(self.folder_names[5])
+        trainer, model = self.get_classification_trainer(self.experiment_names[5])
         params = self.training_params.copy()
         params = self.training_params.copy()
         params["save_ckpt_epoch_list"] = [1]
         params["save_ckpt_epoch_list"] = [1]
         trainer.train(model=model, training_params=params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
         trainer.train(model=model, training_params=params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
Discard