Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#367 fix: Request correct hydra-core

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:hotfix/ALG-000_hydra-req
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
  1. import unittest
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from super_gradients.training.models.all_architectures import ARCHITECTURES
  6. from super_gradients.training.models.model_registry import register
  7. class ModelRegistryTest(unittest.TestCase):
  8. def setUp(self):
  9. @register('myconvnet')
  10. class MyConvNet(nn.Module):
  11. def __init__(self, num_classes):
  12. super().__init__()
  13. self.conv1 = nn.Conv2d(3, 6, 5)
  14. self.pool = nn.MaxPool2d(2, 2)
  15. self.conv2 = nn.Conv2d(6, 16, 5)
  16. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  17. self.fc2 = nn.Linear(120, 84)
  18. self.fc3 = nn.Linear(84, num_classes)
  19. def forward(self, x):
  20. x = self.pool(F.relu(self.conv1(x)))
  21. x = self.pool(F.relu(self.conv2(x)))
  22. x = torch.flatten(x, 1)
  23. x = F.relu(self.fc1(x))
  24. x = F.relu(self.fc2(x))
  25. x = self.fc3(x)
  26. return x
  27. @register()
  28. def myconvnet_for_cifar10():
  29. return MyConvNet(num_classes=10)
  30. def tearDown(self):
  31. ARCHITECTURES.pop('myconvnet', None)
  32. ARCHITECTURES.pop('myconvnet_for_cifar10', None)
  33. def test_cls_is_registered(self):
  34. assert ARCHITECTURES['myconvnet']
  35. def test_fn_is_registered(self):
  36. assert ARCHITECTURES['myconvnet_for_cifar10']
  37. def test_model_is_instantiable(self):
  38. assert ARCHITECTURES['myconvnet_for_cifar10']()
  39. assert ARCHITECTURES['myconvnet'](num_classes=10)
  40. def test_model_outputs(self):
  41. torch.manual_seed(0)
  42. model_1 = ARCHITECTURES['myconvnet_for_cifar10']()
  43. torch.manual_seed(0)
  44. model_2 = ARCHITECTURES['myconvnet'](num_classes=10)
  45. dummy_input = torch.randn(1, 3, 32, 32, requires_grad=False)
  46. x = model_1(dummy_input)
  47. y = model_2(dummy_input)
  48. assert torch.equal(x, y)
  49. def test_existing_key(self):
  50. with self.assertRaises(Exception):
  51. @register()
  52. def myconvnet_for_cifar10():
  53. return
  54. if __name__ == '__main__':
  55. unittest.main()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file