Deci-AI:master
from
deci-ai:feature/sg-326-replace_function_with_class_in_architectures
To train a model, it is necessary to configure 4 main components. These components are aggregated into a single "main" recipe .yaml file that inherits the aforementioned dataset, architecture, training and checkpoint params.
Recipes support out of the box every model, metric or loss that is implemented in SuperGradients, but you can easily extend this to any custom object that you need by "registering it".
Notes:
In your python script
torchmetrics.Metric
torch.nn.Module
torch.nn.modules.loss._Loss
from super_gradients.training.utils.registry import register_metric
from super_gradients.training.utils.registry import register_model
name: str
argument. If not specified, the decorated class name will be registered.In your recipe (.yaml)
main.py
import omegaconf
import hydra
import torch
import torchmetrics
from super_gradients import Trainer, init_trainer
from super_gradients.common.registry.registry import register_metric
@register_metric('custom_accuracy') # Will be registered as "custom_accuracy"
class CustomAccuracy(torchmetrics.Accuracy):
def update(self, preds: torch.Tensor, target: torch.Tensor):
if target.shape == preds.shape:
target = target.argmax(1) # Supports smooth labels
super().update(preds=preds.argmax(1), target=target)
@hydra.main(config_path="recipes")
def main(cfg: omegaconf.DictConfig) -> None:
Trainer.train_from_config(cfg)
init_trainer()
main()
recipes/training_hyperparams/my_training_hyperparams.yaml
... # Other training hyperparams
train_metrics_list:
- custom_accuracy
valid_metrics_list:
- custom_accuracy
Launch the script
python main.py --config-name=my_recipe.yaml
Coming soon
Coming soon
Press p or to see the previous file or, n or to see the next file