#406 Feature/SG 293 - Add register loss

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-293-register_loss
@@ -85,4 +85,44 @@ python main.py --config-name=my_recipe.yaml
 Coming soon
 
 ### C. Loss
-Coming soon
+
+*main.py*
+
+```python
+import omegaconf
+import hydra
+
+import torch
+
+from super_gradients import Trainer, init_trainer
+from super_gradients.common.registry.registry import register_loss
+
+
+@register_loss("custom_rsquared_loss")
+class CustomRSquaredLoss(torch.nn.modules.loss._Loss):  # The loss must inherit from torch's _Loss class.
+    def forward(self, output, target):
+        criterion_mse = torch.nn.MSELoss()
+        # Return a tensor (no .item()) so autograd can backpropagate through the loss.
+        return 1 - criterion_mse(output, target) / torch.var(target)
+
+
+@hydra.main(config_path="recipes")
+def main(cfg: omegaconf.DictConfig) -> None:
+    Trainer.train_from_config(cfg)
+
+
+init_trainer()
+main()
+```
+
+*recipes/training_hyperparams/my_training_hyperparams.yaml* 
+```yaml
+... # Other training hyperparams
+
+loss: custom_rsquared_loss
+```
+
+*Launch the script*
+```bash
+python main.py --config-name=my_recipe.yaml
+```
+
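To sanity-check the registration before launching a full run, one can inspect the `LOSSES` dict that `register_loss` writes into (the registry path is taken from the imports in this PR; the snippet itself is an illustrative check, not part of the diff):

```python
# Illustrative check, not part of the PR: once the module defining
# CustomRSquaredLoss has been imported, the decorator has already
# added the class to the LOSSES registry under the chosen name.
from super_gradients.training.losses.all_losses import LOSSES

assert "custom_rsquared_loss" in LOSSES  # The key passed to @register_loss.
loss_cls = LOSSES["custom_rsquared_loss"]
print(loss_cls)  # The registered CustomRSquaredLoss class.
```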
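`my_recipe.yaml` itself is not part of this diff. Under Hydra, a minimal recipe that picks up the hyperparams file above would typically compose it through a `defaults` list; the group and file names below are assumptions inferred from the paths in the example, not the PR's actual recipe:

```yaml
# Hypothetical recipes/my_recipe.yaml (illustrative; not part of the PR).
defaults:
  - training_hyperparams: my_training_hyperparams
  - _self_
```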
@@ -1,4 +1,4 @@
-from super_gradients.common.registry.registry import register_model, register_metric
+from super_gradients.common.registry.registry import register_model, register_metric, register_loss
 
 
-__all__ = ['register_model', 'register_metric']
+__all__ = ['register_model', 'register_metric', 'register_loss']
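With `register_loss` re-exported in `__all__`, either import path resolves to the same decorator (a small illustrative note, not part of the diff):

```python
from super_gradients.common.registry import register_loss            # Via the package __init__.
# from super_gradients.common.registry.registry import register_loss # Equivalent, via the module itself.
```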
@@ -3,6 +3,7 @@ from typing import Callable, Dict, Optional
 
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.metrics.all_metrics import METRICS
+from super_gradients.training.losses.all_losses import LOSSES
 
 
 def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
@@ -35,3 +36,4 @@ def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
 
 register_model = create_register_decorator(registry=ARCHITECTURES)
 register_metric = create_register_decorator(registry=METRICS)
+register_loss = create_register_decorator(registry=LOSSES)
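The body of `create_register_decorator` sits outside this hunk's context lines. As a rough sketch, a decorator factory of this shape usually looks something like the following (illustrative only, not the exact SuperGradients implementation):

```python
from typing import Callable, Dict, Optional


def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
    """Build a decorator that records classes/functions in the given registry dict."""

    def register(name: Optional[str] = None) -> Callable:
        def decorator(obj: Callable) -> Callable:
            key = name if name is not None else obj.__name__
            if key in registry:
                raise ValueError(f"'{key}' is already registered")
            registry[key] = obj  # Map the string key to the class or factory function.
            return obj  # Return the object unchanged so it stays directly usable.

        return decorator

    return register
```

This is why `@register_loss("custom_rsquared_loss")` both records the class under that key and leaves the class itself untouched.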
@@ -4,10 +4,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics
+from torch.nn.modules.loss import _Loss
 
 from super_gradients.training.models.all_architectures import ARCHITECTURES
 from super_gradients.training.metrics.all_metrics import METRICS
-from super_gradients.common.registry import register_model, register_metric
+from super_gradients.training.losses.all_losses import LOSSES
+from super_gradients.common.registry import register_model, register_metric, register_loss
 
 
 class RegistryTest(unittest.TestCase):
@@ -45,14 +47,22 @@ class RegistryTest(unittest.TestCase):
                     target = target.argmax(1)  # Supports smooth labels
                     super().update(preds=preds.argmax(1), target=target)
 
+        @register_loss("custom_rsquared_loss")
+        class CustomRSquaredLoss(_Loss):
+            def forward(self, output, target):
+                criterion_mse = nn.MSELoss()
+                return 1 - criterion_mse(output, target).item() / torch.var(target).item()
+
     def tearDown(self):
         ARCHITECTURES.pop('myconvnet', None)
         ARCHITECTURES.pop('myconvnet_for_cifar10', None)
         METRICS.pop('custom_accuracy', None)
+        LOSSES.pop('custom_rsquared_loss', None)
 
     def test_cls_is_registered(self):
         assert ARCHITECTURES['myconvnet']
         assert METRICS['custom_accuracy']
+        assert LOSSES['custom_rsquared_loss']
 
     def test_fn_is_registered(self):
         assert ARCHITECTURES['myconvnet_for_cifar10']
@@ -61,6 +71,7 @@ class RegistryTest(unittest.TestCase):
         assert ARCHITECTURES['myconvnet_for_cifar10']()
         assert ARCHITECTURES['myconvnet'](num_classes=10)
         assert METRICS['custom_accuracy']()
+        assert LOSSES['custom_rsquared_loss']()
 
     def test_model_outputs(self):
         torch.manual_seed(0)
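Beyond the unit tests, a quick way to exercise the registered loss is to fetch it back from the registry and call it directly (illustrative snippet; it assumes a module that runs the `@register_loss` decorator, such as the `main.py` example above, has already been imported):

```python
# Illustrative usage, not part of the PR.
import torch
from super_gradients.training.losses.all_losses import LOSSES

loss_fn = LOSSES["custom_rsquared_loss"]()         # Instantiate the registered class.
value = loss_fn(torch.randn(16), torch.randn(16))  # R-squared-style value on random data.
print(value)
```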