|
@@ -4,10 +4,12 @@ import torch
|
|
import torch.nn as nn
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.nn.functional as F
|
|
import torchmetrics
|
|
import torchmetrics
|
|
|
|
+from torch.nn.modules.loss import _Loss
|
|
|
|
|
|
from super_gradients.training.models.all_architectures import ARCHITECTURES
|
|
from super_gradients.training.models.all_architectures import ARCHITECTURES
|
|
from super_gradients.training.metrics.all_metrics import METRICS
|
|
from super_gradients.training.metrics.all_metrics import METRICS
|
|
-from super_gradients.common.registry import register_model, register_metric
|
|
|
|
|
|
+from super_gradients.training.losses.all_losses import LOSSES
|
|
|
|
+from super_gradients.common.registry import register_model, register_metric, register_loss
|
|
|
|
|
|
|
|
|
|
class RegistryTest(unittest.TestCase):
|
|
class RegistryTest(unittest.TestCase):
|
|
@@ -45,14 +47,22 @@ class RegistryTest(unittest.TestCase):
|
|
target = target.argmax(1) # Supports smooth labels
|
|
target = target.argmax(1) # Supports smooth labels
|
|
super().update(preds=preds.argmax(1), target=target)
|
|
super().update(preds=preds.argmax(1), target=target)
|
|
|
|
|
|
|
|
+ @register_loss("custom_rsquared_loss")
|
|
|
|
+ class CustomRSquaredLoss(_Loss):
|
|
|
|
+ def forward(self, output, target):
|
|
|
|
+ criterion_mse = nn.MSELoss()
|
|
|
|
+ return 1 - criterion_mse(output, target).item() / torch.var(target).item()
|
|
|
|
+
|
|
def tearDown(self):
|
|
def tearDown(self):
|
|
ARCHITECTURES.pop('myconvnet', None)
|
|
ARCHITECTURES.pop('myconvnet', None)
|
|
ARCHITECTURES.pop('myconvnet_for_cifar10', None)
|
|
ARCHITECTURES.pop('myconvnet_for_cifar10', None)
|
|
METRICS.pop('custom_accuracy', None)
|
|
METRICS.pop('custom_accuracy', None)
|
|
|
|
+ LOSSES.pop('custom_rsquared_loss', None)
|
|
|
|
|
|
def test_cls_is_registered(self):
|
|
def test_cls_is_registered(self):
|
|
assert ARCHITECTURES['myconvnet']
|
|
assert ARCHITECTURES['myconvnet']
|
|
assert METRICS['custom_accuracy']
|
|
assert METRICS['custom_accuracy']
|
|
|
|
+ assert LOSSES['custom_rsquared_loss']
|
|
|
|
|
|
def test_fn_is_registered(self):
|
|
def test_fn_is_registered(self):
|
|
assert ARCHITECTURES['myconvnet_for_cifar10']
|
|
assert ARCHITECTURES['myconvnet_for_cifar10']
|
|
@@ -61,6 +71,7 @@ class RegistryTest(unittest.TestCase):
|
|
assert ARCHITECTURES['myconvnet_for_cifar10']()
|
|
assert ARCHITECTURES['myconvnet_for_cifar10']()
|
|
assert ARCHITECTURES['myconvnet'](num_classes=10)
|
|
assert ARCHITECTURES['myconvnet'](num_classes=10)
|
|
assert METRICS['custom_accuracy']()
|
|
assert METRICS['custom_accuracy']()
|
|
|
|
+ assert LOSSES['custom_rsquared_loss']()
|
|
|
|
|
|
def test_model_outputs(self):
|
|
def test_model_outputs(self):
|
|
torch.manual_seed(0)
|
|
torch.manual_seed(0)
|