|
@@ -10,6 +10,7 @@ import os
|
|
from super_gradients import Trainer
|
|
from super_gradients import Trainer
|
|
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
|
|
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
|
|
from super_gradients.training.metrics import Accuracy, Top5
|
|
from super_gradients.training.metrics import Accuracy, Top5
|
|
|
|
+from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path
|
|
|
|
|
|
|
|
|
|
class TestTrainer(unittest.TestCase):
|
|
class TestTrainer(unittest.TestCase):
|
|
@@ -17,7 +18,7 @@ class TestTrainer(unittest.TestCase):
|
|
def setUp(cls):
|
|
def setUp(cls):
|
|
super_gradients.init_trainer()
|
|
super_gradients.init_trainer()
|
|
# NAMES FOR THE EXPERIMENTS TO LATER DELETE
|
|
# NAMES FOR THE EXPERIMENTS TO LATER DELETE
|
|
- cls.folder_names = ["test_train", "test_save_load", "test_load_w", "test_load_w2", "test_load_w3", "test_checkpoint_content", "analyze"]
|
|
|
|
|
|
+ cls.experiment_names = ["test_train", "test_save_load", "test_load_w", "test_load_w2", "test_load_w3", "test_checkpoint_content", "analyze"]
|
|
cls.training_params = {
|
|
cls.training_params = {
|
|
"max_epochs": 1,
|
|
"max_epochs": 1,
|
|
"silent_mode": True,
|
|
"silent_mode": True,
|
|
@@ -34,10 +35,11 @@ class TestTrainer(unittest.TestCase):
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
def tearDownClass(cls) -> None:
|
|
def tearDownClass(cls) -> None:
|
|
- # ERASE ALL THE FOLDERS THAT WERE CREATED DURING THIS TEST
|
|
|
|
- for folder in cls.folder_names:
|
|
|
|
- if os.path.isdir(os.path.join("checkpoints", folder)):
|
|
|
|
- shutil.rmtree(os.path.join("checkpoints", folder))
|
|
|
|
|
|
+ # ERASE ALL THE EXPERIMENT FOLDERS THAT WERE CREATED DURING THIS TEST
|
|
|
|
+ for experiment_name in cls.experiment_names:
|
|
|
|
+ experiment_dir = get_checkpoints_dir_path(experiment_name=experiment_name)
|
|
|
|
+ if os.path.isdir(experiment_dir):
|
|
|
|
+ shutil.rmtree(experiment_dir)
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
def get_classification_trainer(name=""):
|
|
def get_classification_trainer(name=""):
|
|
@@ -46,27 +48,27 @@ class TestTrainer(unittest.TestCase):
|
|
return trainer, model
|
|
return trainer, model
|
|
|
|
|
|
def test_train(self):
|
|
def test_train(self):
|
|
- trainer, model = self.get_classification_trainer(self.folder_names[0])
|
|
|
|
|
|
+ trainer, model = self.get_classification_trainer(self.experiment_names[0])
|
|
trainer.train(
|
|
trainer.train(
|
|
model=model, training_params=self.training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
|
|
model=model, training_params=self.training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
|
|
)
|
|
)
|
|
|
|
|
|
def test_save_load(self):
|
|
def test_save_load(self):
|
|
- trainer, model = self.get_classification_trainer(self.folder_names[1])
|
|
|
|
|
|
+ trainer, model = self.get_classification_trainer(self.experiment_names[1])
|
|
trainer.train(
|
|
trainer.train(
|
|
model=model, training_params=self.training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
|
|
model=model, training_params=self.training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
|
|
)
|
|
)
|
|
resume_training_params = self.training_params.copy()
|
|
resume_training_params = self.training_params.copy()
|
|
resume_training_params["resume"] = True
|
|
resume_training_params["resume"] = True
|
|
resume_training_params["max_epochs"] = 2
|
|
resume_training_params["max_epochs"] = 2
|
|
- trainer, model = self.get_classification_trainer(self.folder_names[1])
|
|
|
|
|
|
+ trainer, model = self.get_classification_trainer(self.experiment_names[1])
|
|
trainer.train(
|
|
trainer.train(
|
|
model=model, training_params=resume_training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
|
|
model=model, training_params=resume_training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
|
|
)
|
|
)
|
|
|
|
|
|
def test_checkpoint_content(self):
|
|
def test_checkpoint_content(self):
|
|
"""VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
|
|
"""VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
|
|
- trainer, model = self.get_classification_trainer(self.folder_names[5])
|
|
|
|
|
|
+ trainer, model = self.get_classification_trainer(self.experiment_names[5])
|
|
params = self.training_params.copy()
|
|
params = self.training_params.copy()
|
|
params["save_ckpt_epoch_list"] = [1]
|
|
params["save_ckpt_epoch_list"] = [1]
|
|
trainer.train(model=model, training_params=params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
|
|
trainer.train(model=model, training_params=params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
|