Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#869 Add DagsHub Logger to Super Gradients

Merged
Ghost merged 1 commit into Deci-AI:master from timho102003:dagshub_logger
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
  1. import unittest
  2. import torch
  3. from super_gradients import Trainer
  4. from super_gradients.common.object_names import Models
  5. from super_gradients.training import models
  6. from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
  7. from super_gradients.training.metrics import Accuracy
  8. class CallTrainAfterTestTest(unittest.TestCase):
  9. """
  10. CallTrainTwiceTest
  11. Purpose is to call train after test and see nothing crashes. Should be ran with available GPUs (when possible)
  12. so when calling train again we see there's no change in the model's device.
  13. """
  14. def test_call_train_after_test(self):
  15. trainer = Trainer("test_call_train_after_test")
  16. dataloader = classification_test_dataloader(batch_size=10)
  17. model = models.get(Models.RESNET18, num_classes=5)
  18. train_params = {
  19. "max_epochs": 2,
  20. "lr_updates": [1],
  21. "lr_decay_factor": 0.1,
  22. "lr_mode": "step",
  23. "lr_warmup_epochs": 0,
  24. "initial_lr": 0.1,
  25. "loss": torch.nn.CrossEntropyLoss(),
  26. "optimizer": "SGD",
  27. "criterion_params": {},
  28. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  29. "train_metrics_list": [Accuracy()],
  30. "valid_metrics_list": [Accuracy()],
  31. "metric_to_watch": "Accuracy",
  32. "greater_metric_to_watch_is_better": True,
  33. }
  34. trainer.test(model=model, test_metrics_list=[Accuracy()], test_loader=dataloader)
  35. trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
  36. def test_call_train_after_test_with_loss(self):
  37. trainer = Trainer("test_call_train_after_test_with_loss")
  38. dataloader = classification_test_dataloader(batch_size=10)
  39. model = models.get(Models.RESNET18, num_classes=5)
  40. train_params = {
  41. "max_epochs": 2,
  42. "lr_updates": [1],
  43. "lr_decay_factor": 0.1,
  44. "lr_mode": "step",
  45. "lr_warmup_epochs": 0,
  46. "initial_lr": 0.1,
  47. "loss": torch.nn.CrossEntropyLoss(),
  48. "optimizer": "SGD",
  49. "criterion_params": {},
  50. "optimizer_params": {"weight_decay": 1e-4, "momentum": 0.9},
  51. "train_metrics_list": [Accuracy()],
  52. "valid_metrics_list": [Accuracy()],
  53. "metric_to_watch": "Accuracy",
  54. "greater_metric_to_watch_is_better": True,
  55. }
  56. trainer.test(model=model, test_metrics_list=[Accuracy()], test_loader=dataloader, loss=torch.nn.CrossEntropyLoss())
  57. trainer.train(model=model, training_params=train_params, train_loader=dataloader, valid_loader=dataloader)
# Allow running this test module directly (e.g. `python this_file.py`)
# in addition to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()
Discard
Tip!

Press p to see the previous file, or n to see the next file