Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train_after_test_test.py 3.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
  1. import unittest
  2. import torch
  3. from super_gradients import Trainer
  4. from super_gradients.common.object_names import Models
  5. from super_gradients.training import models
  6. from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
  7. from super_gradients.training.metrics import Accuracy
  8. class CallTrainAfterTestTest(unittest.TestCase):
  9. """
  10. CallTrainTwiceTest
  11. Purpose is to call train after test and see nothing crashes. Should be ran with available GPUs (when possible)
  12. so when calling train again we see there's no change in the model's device.
  13. """
  14. def setUp(self) -> None:
  15. self.train_params = {
  16. "max_epochs": 2,
  17. "lr_updates": [1],
  18. "lr_decay_factor": 0.1,
  19. "lr_mode": "StepLRScheduler",
  20. "lr_warmup_epochs": 0,
  21. "initial_lr": 0.1,
  22. "loss": torch.nn.CrossEntropyLoss(),
  23. "optimizer": "SGD",
  24. "criterion_params": {},
  25. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  26. "train_metrics_list": [Accuracy()],
  27. "valid_metrics_list": [Accuracy()],
  28. "metric_to_watch": "Accuracy",
  29. "greater_metric_to_watch_is_better": True,
  30. }
  31. def test_call_train_after_test(self):
  32. trainer = Trainer("test_call_train_after_test")
  33. dataloader = classification_test_dataloader(batch_size=10)
  34. model = models.get(Models.RESNET18, num_classes=5)
  35. trainer.test(model=model, test_metrics_list=[Accuracy()], test_loader=dataloader)
  36. trainer.train(model=model, training_params=self.train_params, train_loader=dataloader, valid_loader=dataloader)
  37. def test_call_train_after_test_with_loss(self):
  38. trainer = Trainer("test_call_train_after_test_with_loss")
  39. dataloader = classification_test_dataloader(batch_size=10)
  40. model = models.get(Models.RESNET18, num_classes=5)
  41. trainer.test(model=model, test_metrics_list=[Accuracy()], test_loader=dataloader, loss=torch.nn.CrossEntropyLoss())
  42. trainer.train(model=model, training_params=self.train_params, train_loader=dataloader, valid_loader=dataloader)
  43. def test_training_with_testset_after_test(self):
  44. trainer = Trainer("training_with_testset_after_test")
  45. dataloader = classification_test_dataloader(batch_size=10)
  46. model = models.get(Models.RESNET18, num_classes=5)
  47. trainer.test(model=model, test_metrics_list=[Accuracy()], test_loader=dataloader)
  48. trainer.train(
  49. model=model,
  50. training_params=self.train_params,
  51. train_loader=dataloader,
  52. valid_loader=dataloader,
  53. test_loaders={"test1": dataloader, "test2": dataloader},
  54. )
  55. def test_test_after_training_with_testset(self):
  56. trainer = Trainer("test_after_training_with_testset")
  57. dataloader = classification_test_dataloader(batch_size=10)
  58. model = models.get(Models.RESNET18, num_classes=5)
  59. trainer.train(
  60. model=model,
  61. training_params=self.train_params,
  62. train_loader=dataloader,
  63. valid_loader=dataloader,
  64. test_loaders={"test1": dataloader, "test2": dataloader},
  65. )
  66. trainer.test(model=model, test_metrics_list=[Accuracy()], test_loader=dataloader, loss=torch.nn.CrossEntropyLoss())
  67. if __name__ == "__main__":
  68. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...