Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#326 Feature/SG-245 Support for register model in model's factory

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:Feature/model_registry
@@ -17,6 +17,6 @@ from super_gradients.training.models.classification_models.vgg import *
 from super_gradients.training.models.classification_models.vit import *
 from super_gradients.training.models.classification_models.vit import *
 from super_gradients.training.models.segmentation_models.shelfnet import *
 from super_gradients.training.models.segmentation_models.shelfnet import *
 from super_gradients.training.models.classification_models.efficientnet import *
 from super_gradients.training.models.classification_models.efficientnet import *
-
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.models.all_architectures import ARCHITECTURES
-from super_gradients.training.models.model_factory import get
+from super_gradients.training.models.user_models import *
+from super_gradients.training.models.model_factory import get
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
  1. import inspect
  2. from typing import Callable, Optional
  3. from super_gradients.training.models.all_architectures import ARCHITECTURES
  4. def register(name: Optional[str] = None) -> Callable:
  5. def model_decorator(model: Callable) -> Callable:
  6. if name is not None:
  7. model_name = name
  8. else:
  9. model_name = model.__name__
  10. if model_name in ARCHITECTURES:
  11. ref = ARCHITECTURES[model_name]
  12. raise Exception(
  13. f"`{model_name}` is already registered and points to `{inspect.getmodule(ref).__name__}.{ref.__name__}"
  14. )
  15. ARCHITECTURES[model_name] = model
  16. return model
  17. return model_decorator
Discard


Introduction

This page demonstrates how you can register your own models so that SuperGradients can access them by name (a str) — for example, when training from a recipe config: architecture: my_custom_model.

Usage

  1. Create a new Python module in this folder (e.g. .../user_models/my_model.py).
  2. Define your PyTorch model (torch.nn.Module) in the new module.
  3. Import the @register decorator (`from super_gradients.training.models.model_registry import register`) and apply it to your model.
    • The decorator can be applied directly to the class or to a function returning the class.
    • The decorator takes an optional name: str argument. If not specified, the decorated class/function name will be registered.

Example

import torch  # needed for torch.flatten below; `import torch.nn as nn` does not bind `torch`
import torch.nn as nn
import torch.nn.functional as F

from super_gradients.training.models.model_registry import register

@register('my_conv_net') # will be registered as "my_conv_net"
class MyConvNet(nn.Module):
    """Minimal example CNN (two conv+pool stages, three FC layers).

    :param num_classes: Size of the final classification layer's output.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 * 5 * 5 assumes a 32x32 input (e.g. CIFAR) after two conv+pool stages.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

or

@register()
def myconvnet_for_cifar10():
    # No name argument given, so this is registered as "myconvnet_for_cifar10".
    # Factory wrapping MyConvNet with CIFAR-10's 10 output classes.
    return MyConvNet(num_classes=10)
Discard
1
2
3
4
  1. import os
  2. import pkgutil
  3. __all__ = list(module for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)]))
Discard
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
  1. import unittest
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from super_gradients.training.models.all_architectures import ARCHITECTURES
  6. from super_gradients.training.models.model_registry import register
  7. class ModelRegistryTest(unittest.TestCase):
  8. def setUp(self):
  9. @register('myconvnet')
  10. class MyConvNet(nn.Module):
  11. def __init__(self, num_classes):
  12. super().__init__()
  13. self.conv1 = nn.Conv2d(3, 6, 5)
  14. self.pool = nn.MaxPool2d(2, 2)
  15. self.conv2 = nn.Conv2d(6, 16, 5)
  16. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  17. self.fc2 = nn.Linear(120, 84)
  18. self.fc3 = nn.Linear(84, num_classes)
  19. def forward(self, x):
  20. x = self.pool(F.relu(self.conv1(x)))
  21. x = self.pool(F.relu(self.conv2(x)))
  22. x = torch.flatten(x, 1)
  23. x = F.relu(self.fc1(x))
  24. x = F.relu(self.fc2(x))
  25. x = self.fc3(x)
  26. return x
  27. @register()
  28. def myconvnet_for_cifar10():
  29. return MyConvNet(num_classes=10)
  30. def tearDown(self):
  31. ARCHITECTURES.pop('myconvnet', None)
  32. ARCHITECTURES.pop('myconvnet_for_cifar10', None)
  33. def test_cls_is_registered(self):
  34. assert ARCHITECTURES['myconvnet']
  35. def test_fn_is_registered(self):
  36. assert ARCHITECTURES['myconvnet_for_cifar10']
  37. def test_model_is_instantiable(self):
  38. assert ARCHITECTURES['myconvnet_for_cifar10']()
  39. assert ARCHITECTURES['myconvnet'](num_classes=10)
  40. def test_model_outputs(self):
  41. torch.manual_seed(0)
  42. model_1 = ARCHITECTURES['myconvnet_for_cifar10']()
  43. torch.manual_seed(0)
  44. model_2 = ARCHITECTURES['myconvnet'](num_classes=10)
  45. dummy_input = torch.randn(1, 3, 32, 32, requires_grad=False)
  46. x = model_1(dummy_input)
  47. y = model_2(dummy_input)
  48. assert torch.equal(x, y)
  49. def test_existing_key(self):
  50. with self.assertRaises(Exception):
  51. @register()
  52. def myconvnet_for_cifar10():
  53. return
  54. if __name__ == '__main__':
  55. unittest.main()
Discard