#479 Bugfix - calling train after test

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:bug/SG-401_train_after_test_fix
@@ -146,12 +146,14 @@ class Trainer:
         self.enable_qat = False
         self.qat_params = {}
         self._infinite_train_loader = False
+        self._first_backward = True

         # METRICS
         self.loss_logging_items_names = None
         self.train_metrics = None
         self.valid_metrics = None
         self.greater_metric_to_watch_is_better = None
+        self.metric_to_watch = None

         # SETTING THE PROPERTIES FROM THE CONSTRUCTOR
         self.experiment_name = experiment_name
@@ -426,7 +428,8 @@ class Trainer:
         # ON FIRST BACKWARD, DERRIVE THE LOGGING TITLES.
         if self.loss_logging_items_names is None or self._first_backward:
             self._init_loss_logging_names(loss_logging_items)
-            self._init_monitored_items()
+            if self.metric_to_watch:
+                self._init_monitored_items()
             self._first_backward = False

         if len(loss_logging_items) != len(self.loss_logging_items_names):
@@ -1654,6 +1657,8 @@ class Trainer:
         if use_ema_net and self.ema_model is not None:
             self.net = keep_model

+        self._first_backward = True
+
         return test_results

     def _validate_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
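
Taken together, the three hunks address the crash this PR fixes: test() left loss_logging_items_names populated and could reach _init_monitored_items() while metric_to_watch was still unset, so a train() call that followed a test() call worked off stale logging state. The patch tracks a _first_backward flag, guards _init_monitored_items() behind metric_to_watch, and resets the flag at the end of test() so the next train() re-derives its logging titles. A minimal sketch of the resulting control flow, assuming heavily simplified Trainer internals (TrainerSketch, _on_backward, the placeholder title derivation, and the stubbed _init_monitored_items are illustrative stand-ins, not the library's real code):

class TrainerSketch:
    """Heavily simplified stand-in for super_gradients' Trainer; illustrative only."""

    def __init__(self):
        self.loss_logging_items_names = None
        self.metric_to_watch = None  # stays None for a standalone test() call
        self._first_backward = True

    def _init_monitored_items(self):
        # Placeholder: the real method wires up best-metric tracking, which
        # only makes sense once metric_to_watch has been set by train().
        assert self.metric_to_watch is not None

    def _on_backward(self, loss_logging_items):
        # Titles are (re-)derived on the first backward of each run, not only
        # when the names are still None from a freshly constructed Trainer.
        if self.loss_logging_items_names is None or self._first_backward:
            self.loss_logging_items_names = [f"Loss/{i}" for i in range(len(loss_logging_items))]
            if self.metric_to_watch:  # new guard: skipped during a bare test()
                self._init_monitored_items()
            self._first_backward = False

    def test(self):
        # ... evaluation loop elided ...
        self._first_backward = True  # next train() re-derives the titles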
@@ -18,6 +18,7 @@ from tests.unit_tests import (
     TrainOptimizerParamsOverride,
     CallTrainTwiceTest,
     ResumeTrainingTest,
+    CallTrainAfterTestTest,
 )
 from tests.end_to_end_tests import TestTrainer
 from tests.unit_tests.phase_delegates_test import ContextMethodsTest
@@ -95,6 +96,7 @@ class CoreUnitTestSuiteRunner:
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TrainOptimizerParamsOverride))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(PPYoloETests))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ResumeTrainingTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(CallTrainAfterTestTest))

     def _add_modules_to_end_to_end_tests_suite(self):
         """
@@ -4,6 +4,7 @@ from tests.unit_tests.factories_test import FactoriesTest
 from tests.unit_tests.optimizer_params_override_test import TrainOptimizerParamsOverride
 from tests.unit_tests.resume_training_test import ResumeTrainingTest
 from tests.unit_tests.strictload_enum_test import StrictLoadEnumTest
+from tests.unit_tests.train_after_test_test import CallTrainAfterTestTest
 from tests.unit_tests.zero_weight_decay_on_bias_bn_test import ZeroWdForBnBiasTest
 from tests.unit_tests.save_ckpt_test import SaveCkptListUnitTest
 from tests.unit_tests.all_architectures_test import AllArchitecturesTest
@@ -41,4 +42,5 @@ __all__ = [
     "TrainOptimizerParamsOverride",
     "CallTrainTwiceTest",
     "ResumeTrainingTest",
+    "CallTrainAfterTestTest",
 ]
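
With CallTrainAfterTestTest exported from tests.unit_tests and registered in the suite runner above, the new case participates in the core unit test suite. The same wiring can also be exercised in isolation with plain standard-library unittest; a minimal sketch (the TestLoader/TestSuite scaffolding below is stock unittest, not the repo's CoreUnitTestSuiteRunner):

import unittest

from tests.unit_tests import CallTrainAfterTestTest  # import path enabled by the __init__.py change above

# Build a one-off suite containing just the new test case and run it.
loader = unittest.TestLoader()
suite = unittest.TestSuite()
suite.addTest(loader.loadTestsFromTestCase(CallTrainAfterTestTest))
unittest.TextTestRunner(verbosity=2).run(suite)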
import unittest

import torch

from super_gradients import Trainer
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
from super_gradients.training.metrics import Accuracy


class CallTrainAfterTestTest(unittest.TestCase):
    """
    CallTrainAfterTestTest

    Purpose is to call train after test and see that nothing crashes. Should be run with available GPUs
    (when possible) so that when calling train again we see there is no change in the model's device.
    """

    def test_call_train_after_test(self):
        trainer = Trainer("test_call_train_after_test")
        dataloader = classification_test_dataloader(batch_size=10)
        model = models.get("resnet18", num_classes=5)
        train_params = {
            "max_epochs": 2,
            "lr_updates": [1],
            "lr_decay_factor": 0.1,
            "lr_mode": "step",
            "lr_warmup_epochs": 0,
            "initial_lr": 0.1,
            "loss": torch.nn.CrossEntropyLoss(),
            "optimizer": "SGD",
            "criterion_params": {},
            "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
            "train_metrics_list": [Accuracy()],
            "valid_metrics_list": [Accuracy()],
            "metric_to_watch": "Accuracy",
            "greater_metric_to_watch_is_better": True,
        }
        trainer.test(model=model, test_metrics_list=[Accuracy()], test_loader=dataloader)
        trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)

    def test_call_train_after_test_with_loss(self):
        trainer = Trainer("test_call_train_after_test_with_loss")
        dataloader = classification_test_dataloader(batch_size=10)
        model = models.get("resnet18", num_classes=5)
        train_params = {
            "max_epochs": 2,
            "lr_updates": [1],
            "lr_decay_factor": 0.1,
            "lr_mode": "step",
            "lr_warmup_epochs": 0,
            "initial_lr": 0.1,
            "loss": torch.nn.CrossEntropyLoss(),
            "optimizer": "SGD",
            "criterion_params": {},
            "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
            "train_metrics_list": [Accuracy()],
            "valid_metrics_list": [Accuracy()],
            "metric_to_watch": "Accuracy",
            "greater_metric_to_watch_is_better": True,
        }
        trainer.test(model=model, test_metrics_list=[Accuracy()], test_loader=dataloader, loss=torch.nn.CrossEntropyLoss())
        trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)


if __name__ == "__main__":
    unittest.main()
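
Alternatively, assuming the new file lands at tests/unit_tests/train_after_test_test.py (the path the import in the diff above points at), the module can be run directly from the repository root with python -m unittest tests.unit_tests.train_after_test_test.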