|
@@ -5,13 +5,13 @@ import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.nn.functional as F
|
|
|
|
|
|
from super_gradients.training.models.all_architectures import ARCHITECTURES
|
|
from super_gradients.training.models.all_architectures import ARCHITECTURES
|
|
-from super_gradients.training.models.model_registry import register
|
|
|
|
|
|
+from super_gradients.training.models.model_registry import register_model
|
|
|
|
|
|
|
|
|
|
class ModelRegistryTest(unittest.TestCase):
|
|
class ModelRegistryTest(unittest.TestCase):
|
|
|
|
|
|
def setUp(self):
|
|
def setUp(self):
|
|
- @register('myconvnet')
|
|
|
|
|
|
+ @register_model('myconvnet')
|
|
class MyConvNet(nn.Module):
|
|
class MyConvNet(nn.Module):
|
|
def __init__(self, num_classes):
|
|
def __init__(self, num_classes):
|
|
super().__init__()
|
|
super().__init__()
|
|
@@ -31,7 +31,7 @@ class ModelRegistryTest(unittest.TestCase):
|
|
x = self.fc3(x)
|
|
x = self.fc3(x)
|
|
return x
|
|
return x
|
|
|
|
|
|
- @register()
|
|
|
|
|
|
+ @register_model()
|
|
def myconvnet_for_cifar10():
|
|
def myconvnet_for_cifar10():
|
|
return MyConvNet(num_classes=10)
|
|
return MyConvNet(num_classes=10)
|
|
|
|
|
|
@@ -61,7 +61,7 @@ class ModelRegistryTest(unittest.TestCase):
|
|
|
|
|
|
def test_existing_key(self):
|
|
def test_existing_key(self):
|
|
with self.assertRaises(Exception):
|
|
with self.assertRaises(Exception):
|
|
- @register()
|
|
|
|
|
|
+ @register_model()
|
|
def myconvnet_for_cifar10():
|
|
def myconvnet_for_cifar10():
|
|
return
|
|
return
|
|
|
|
|