Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#399 Feature/sg 326 replace function with class in architectures

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/sg-326-replace_function_with_class_in_architectures

How to use your own objects in SuperGradients recipes ?

1. Introduction

To train a model, it is necessary to configure 4 main components. These components are aggregated into a single "main" recipe .yaml file that inherits the aforementioned dataset, architecture, training and checkpoint params.

Recipes support out of the box every model, metric or loss that is implemented in SuperGradients, but you can easily extend this to any custom object that you need by "registering it".

Notes:

2. General flow

In your python script

  1. Define your custom object of type:
    • metric: torchmetrics.Metric
    • model: torch.nn.Module
    • loss: torch.nn.modules.loss._Loss
  2. Import the associated register decorator:
    • metric: from super_gradients.training.utils.registry import register_metric
    • model: from super_gradients.training.utils.registry import register_model
    • loss: coming soon
  3. Apply it on your object.
    • The decorator takes an optional name: str argument. If not specified, the decorated class name will be registered.

In your recipe (.yaml)

  1. Define your recipe like in any other case (you can find examples here).
  2. Modify the recipe by using the registered name (see the following examples).

3. Examples

A. Metric

main.py

import omegaconf
import hydra

import torch
import torchmetrics

from super_gradients import Trainer, init_trainer
from super_gradients.common.registry.registry import register_metric


@register_metric('custom_accuracy')  # Will be registered as "custom_accuracy"
class CustomAccuracy(torchmetrics.Accuracy):
   def update(self, preds: torch.Tensor, target: torch.Tensor):
      if target.shape == preds.shape:
         target = target.argmax(1)  # Supports smooth labels
         super().update(preds=preds.argmax(1), target=target)


@hydra.main(config_path="recipes")
def main(cfg: omegaconf.DictConfig) -> None:
   Trainer.train_from_config(cfg)


init_trainer()
main()

recipes/training_hyperparams/my_training_hyperparams.yaml

... # Other training hyperparams

train_metrics_list:
  - custom_accuracy

valid_metrics_list:
  - custom_accuracy

Launch the script

python main.py --config-name=my_recipe.yaml

B. Model

Coming soon

C. Loss

Coming soon

Discard
Tip!

Press p or to see the previous file or, n or to see the next file