Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#363 rename register to register_model (to support future registrators)

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-292_rename_register
@@ -4,7 +4,7 @@ from typing import Callable, Optional
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 
 
 
 
-def register(name: Optional[str] = None) -> Callable:
+def register_model(name: Optional[str] = None) -> Callable:
     def model_decorator(model: Callable) -> Callable:
     def model_decorator(model: Callable) -> Callable:
         if name is not None:
         if name is not None:
             model_name = name
             model_name = name
Discard
@@ -19,9 +19,10 @@ example, when training from a recipe config `architecture: my_custom_model`.
 import torch.nn as nn
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.functional as F
 
 
-from super_gradients.training.models.model_registry import register
+from super_gradients.training.models.model_registry import register_model
 
 
-@register('my_conv_net') # will be registered as "my_conv_net"
+
+@register_model('my_conv_net')  # will be registered as "my_conv_net"
 class MyConvNet(nn.Module):
 class MyConvNet(nn.Module):
     def __init__(self, num_classes):
     def __init__(self, num_classes):
         super().__init__()
         super().__init__()
@@ -43,7 +44,7 @@ class MyConvNet(nn.Module):
 ```
 ```
 or
 or
 ```python
 ```python
-@register()
+@register_model()
 def myconvnet_for_cifar10(): # will be registered as "myconvnet_for_cifar10"
 def myconvnet_for_cifar10(): # will be registered as "myconvnet_for_cifar10"
     return MyConvNet(num_classes=10)
     return MyConvNet(num_classes=10)
 ```
 ```
Discard
@@ -5,13 +5,13 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.functional as F
 
 
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.models.all_architectures import ARCHITECTURES
-from super_gradients.training.models.model_registry import register
+from super_gradients.training.models.model_registry import register_model
 
 
 
 
 class ModelRegistryTest(unittest.TestCase):
 class ModelRegistryTest(unittest.TestCase):
 
 
     def setUp(self):
     def setUp(self):
-        @register('myconvnet')
+        @register_model('myconvnet')
         class MyConvNet(nn.Module):
         class MyConvNet(nn.Module):
             def __init__(self, num_classes):
             def __init__(self, num_classes):
                 super().__init__()
                 super().__init__()
@@ -31,7 +31,7 @@ class ModelRegistryTest(unittest.TestCase):
                 x = self.fc3(x)
                 x = self.fc3(x)
                 return x
                 return x
 
 
-        @register()
+        @register_model()
         def myconvnet_for_cifar10():
         def myconvnet_for_cifar10():
             return MyConvNet(num_classes=10)
             return MyConvNet(num_classes=10)
 
 
@@ -61,7 +61,7 @@ class ModelRegistryTest(unittest.TestCase):
 
 
     def test_existing_key(self):
     def test_existing_key(self):
         with self.assertRaises(Exception):
         with self.assertRaises(Exception):
-            @register()
+            @register_model()
             def myconvnet_for_cifar10():
             def myconvnet_for_cifar10():
                 return
                 return
 
 
Discard