Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

double_training_test.py 1.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import unittest

import torch
from super_gradients import Trainer
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
from super_gradients.training.metrics import Accuracy


  8. class CallTrainTwiceTest(unittest.TestCase):
  9. """
  10. CallTrainTwiceTest
  11. Purpose is to call train twice and see nothing crashes. Should be ran with available GPUs (when possible)
  12. so when calling train again we see there's no change in the model's device.
  13. """
  14. def test_call_train_twice(self):
  15. trainer = Trainer("external_criterion_test")
  16. dataloader = classification_test_dataloader(batch_size=10)
  17. model = models.get(Models.RESNET18, num_classes=5)
  18. train_params = {
  19. "max_epochs": 2,
  20. "lr_updates": [1],
  21. "lr_decay_factor": 0.1,
  22. "lr_mode": "step",
  23. "lr_warmup_epochs": 0,
  24. "initial_lr": 0.1,
  25. "loss": torch.nn.CrossEntropyLoss(),
  26. "optimizer": "SGD",
  27. "criterion_params": {},
  28. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  29. "train_metrics_list": [Accuracy()],
  30. "valid_metrics_list": [Accuracy()],
  31. "metric_to_watch": "Accuracy",
  32. "greater_metric_to_watch_is_better": True,
  33. }
  34. trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
  35. trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
  36. if __name__ == "__main__":
  37. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...