Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

registry_test.py 3.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
  1. import unittest
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import torchmetrics
  6. from torch.nn.modules.loss import _Loss
  7. from super_gradients.training.models.all_architectures import ARCHITECTURES
  8. from super_gradients.training.metrics.all_metrics import METRICS
  9. from super_gradients.training.losses.all_losses import LOSSES
  10. from super_gradients.common.registry import register_model, register_metric, register_loss
  11. class RegistryTest(unittest.TestCase):
  12. def setUp(self):
  13. @register_model('myconvnet')
  14. class MyConvNet(nn.Module):
  15. def __init__(self, num_classes):
  16. super().__init__()
  17. self.conv1 = nn.Conv2d(3, 6, 5)
  18. self.pool = nn.MaxPool2d(2, 2)
  19. self.conv2 = nn.Conv2d(6, 16, 5)
  20. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  21. self.fc2 = nn.Linear(120, 84)
  22. self.fc3 = nn.Linear(84, num_classes)
  23. def forward(self, x):
  24. x = self.pool(F.relu(self.conv1(x)))
  25. x = self.pool(F.relu(self.conv2(x)))
  26. x = torch.flatten(x, 1)
  27. x = F.relu(self.fc1(x))
  28. x = F.relu(self.fc2(x))
  29. x = self.fc3(x)
  30. return x
  31. @register_model()
  32. def myconvnet_for_cifar10():
  33. return MyConvNet(num_classes=10)
  34. @register_metric('custom_accuracy') # Will be registered as "custom_accuracy"
  35. class CustomAccuracy(torchmetrics.Accuracy):
  36. def update(self, preds: torch.Tensor, target: torch.Tensor):
  37. if target.shape == preds.shape:
  38. target = target.argmax(1) # Supports smooth labels
  39. super().update(preds=preds.argmax(1), target=target)
  40. @register_loss("custom_rsquared_loss")
  41. class CustomRSquaredLoss(_Loss):
  42. def forward(self, output, target):
  43. criterion_mse = nn.MSELoss()
  44. return 1 - criterion_mse(output, target).item() / torch.var(target).item()
  45. def tearDown(self):
  46. ARCHITECTURES.pop('myconvnet', None)
  47. ARCHITECTURES.pop('myconvnet_for_cifar10', None)
  48. METRICS.pop('custom_accuracy', None)
  49. LOSSES.pop('custom_rsquared_loss', None)
  50. def test_cls_is_registered(self):
  51. assert ARCHITECTURES['myconvnet']
  52. assert METRICS['custom_accuracy']
  53. assert LOSSES['custom_rsquared_loss']
  54. def test_fn_is_registered(self):
  55. assert ARCHITECTURES['myconvnet_for_cifar10']
  56. def test_is_instantiable(self):
  57. assert ARCHITECTURES['myconvnet_for_cifar10']()
  58. assert ARCHITECTURES['myconvnet'](num_classes=10)
  59. assert METRICS['custom_accuracy']()
  60. assert LOSSES['custom_rsquared_loss']()
  61. def test_model_outputs(self):
  62. torch.manual_seed(0)
  63. model_1 = ARCHITECTURES['myconvnet_for_cifar10']()
  64. torch.manual_seed(0)
  65. model_2 = ARCHITECTURES['myconvnet'](num_classes=10)
  66. dummy_input = torch.randn(1, 3, 32, 32, requires_grad=False)
  67. x = model_1(dummy_input)
  68. y = model_2(dummy_input)
  69. assert torch.equal(x, y)
  70. def test_existing_key(self):
  71. with self.assertRaises(Exception):
  72. @register_model()
  73. def myconvnet_for_cifar10():
  74. return
  75. if __name__ == '__main__':
  76. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...