SuperGradients supports any PyTorch-based loss function. In addition, loss implementations for various tasks are provided out of the box:
cross_entropy
mse
r_squared_loss
shelfnet_ohem_loss
shelfnet_se_loss
yolox_loss
yolox_fast_loss
ssd_loss
stdc_loss
bce_dice_loss
kd_loss
dice_ce_edge_loss
All of the above are simply string aliases for the underlying torch.nn.Module classes that implement the specified loss functions.
The most basic use case is when using a direct Trainer.train(...) call.
In your my_training_script.py:
...
trainer = Trainer("external_criterion_test")
train_dataloader = ...
valid_dataloader = ...
model = ...
train_params = {
    ...
    "loss": "cross_entropy",
    "criterion_params": {}
    ...
}
trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)
Note that object names in SG are neither case-sensitive nor symbol-sensitive, so "CrossEntropy" could have been passed as well.
Since most IDEs support auto-completion, for your convenience, you can use our object_names module:
from super_gradients.common.object_names import Losses
Then, instead of "cross_entropy", simply use Losses.CROSS_ENTROPY.
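For example, the training_params from the snippet above could use the constant instead of the raw string (a minimal sketch; the rest of the training setup stays as shown earlier):

from super_gradients.common.object_names import Losses

train_params = {
    "loss": Losses.CROSS_ENTROPY,  # resolves to the same loss as the "cross_entropy" alias
    "criterion_params": {}
}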
Another use case is when using configuration files. For example, when training using train_from_recipe (or similar, when the underlying train method that is being called is Trainer.train_from_config(...)).
When doing so, in your my_training_hyperparams.yaml file:
...
loss: yolox_loss
criterion_params:
  strides: [8, 16, 32]  # output strides of all yolo outputs
  num_classes: 80
Note that two training_params entries define the loss function: loss, which defines the type of the loss, and the criterion_params dictionary, which is unpacked into the underlying yolox_loss class constructor.
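Under the hood, this is roughly equivalent to constructing the loss with those keyword arguments. The sketch below only illustrates the idea (it is not SG's internal code, and the YoloXDetectionLoss import path is an assumption):

from super_gradients.training.losses import YoloXDetectionLoss  # assumed import path

criterion_params = {"strides": [8, 16, 32], "num_classes": 80}

# criterion_params is unpacked into the loss constructor, mirroring the YAML above
criterion = YoloXDetectionLoss(**criterion_params)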
SuperGradients also supports passing instantiated nn.Module objects, as demonstrated below.
When using a direct Trainer.train(...) call, in your my_training_script.py simply pass the instantiated nn.Module under the "loss" key inside training_params:
...
trainer = Trainer("external_criterion_test")
train_dataloader = ...
valid_dataloader = ...
model = ...
train_params = {
    ...
    "loss": torch.nn.CrossEntropyLoss()
    ...
}

trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)
Though not as convenient as using register_loss (discussed in more detail in the next sub-section), one can also equivalently instantiate objects when using train_from_recipe (or similar, when the underlying train method is Trainer.train_from_config(...)), as demonstrated below.
In your my_training_hyperparams.yaml file:
...
loss:
  _target_: torch.nn.CrossEntropyLoss
Note that when passing an instantiated loss object, criterion_params will be ignored.
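Since the _target_ node above is instantiated directly, any constructor arguments can typically be placed alongside it rather than in criterion_params. For example (assuming standard Hydra-style instantiation of this node; label_smoothing is used purely as an illustration):

...
loss:
  _target_: torch.nn.CrossEntropyLoss
  label_smoothing: 0.1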
SuperGradients also supports user-defined loss functions, assuming they inherit from torch.nn.Module and their forward signature is of the form:
import torch.nn

class MyLoss(torch.nn.Module):
    ...
    def forward(self, preds, target):
        ...
As the argument names suggest, the first argument is the model's output and the second is the label/ground truth (the argument naming is arbitrary and does not need to be specifically 'preds' or 'target').
Loss functions accepting additional arguments in their forward method will be supported in the future.
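For instance, a minimal user-defined loss following this signature and returning a single tensor might look as follows (the class and its scale parameter are purely illustrative):

import torch
import torch.nn.functional as F


class ScaledMSELoss(torch.nn.Module):
    """Illustrative custom loss: plain MSE scaled by a constant factor."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.scale = scale

    def forward(self, preds, target):
        # returns a single tensor used for backprop
        return self.scale * F.mse_loss(preds, target)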
In the most common case, where the loss function returns a single item for backprop, the loss output will appear in the training logs (i.e., Tensorboard and any other supported SGLogger; see the SGLogger documentation for more information) over epochs, under <LOSS_CLASS.__name__>.
To log multiple values, forward(...) should return a (loss, loss_items) tuple, where loss is the tensor used for backprop (i.e., what your original loss function returns) and loss_items is a tensor of shape (n_items) holding values computed during the forward pass that we wish to log over the entire epoch. For example, the loss itself should always be logged. Another example is a scenario where the computed loss is the sum of a few components we would like to log separately.
For example:
import torch
from torch.nn.modules.loss import _Loss


class MyLoss(_Loss):
    ...
    def forward(self, inputs, targets):
        ...
        total_loss = comp1 + comp2
        # detach so the logged items do not hold on to the computational graph
        loss_items = torch.cat((total_loss.unsqueeze(0), comp1.unsqueeze(0), comp2.unsqueeze(0))).detach()
        return total_loss, loss_items
train_params = {
    ...,
    "loss": MyLoss(),
    "metric_to_watch": "MyLoss/loss_0"
}

trainer.train(
    ...,
    training_params=train_params
)
The above snippet will log MyLoss/loss_0, MyLoss/loss_1 and MyLoss/loss_2, named by their positional index in loss_items.
Note that we also defined "MyLoss/loss_0" as our watched metric, which means a checkpoint is saved on every epoch that reaches the best loss score.
For more visibility, you can also define a "component_names" property in the loss class, returning a list of strings of length n_items whose ith element is the name of the ith entry in loss_items.
Each item will then be logged, rendered on the tensorboard, and "watched" (i.e., model checkpoints are saved according to it) under <LOSS_CLASS.__name__>/<COMPONENT_NAME>.
For example:
import torch
from torch.nn.modules.loss import _Loss


class MyLoss(_Loss):
    ...
    def forward(self, inputs, targets):
        ...
        total_loss = comp1 + comp2
        loss_items = torch.cat((total_loss.unsqueeze(0), comp1.unsqueeze(0), comp2.unsqueeze(0))).detach()
        return total_loss, loss_items

    ...

    @property
    def component_names(self):
        return ["total_loss", "my_1st_component", "my_2nd_component"]
train_params = {
    ...,
    "loss": MyLoss(),
    "metric_to_watch": "MyLoss/my_1st_component"
}

trainer.train(
    ...,
    training_params=train_params
)
The above code will log and monitor MyLoss/total_loss, MyLoss/my_1st_component and MyLoss/my_2nd_component.
Since running logs will save the loss_items in some internal state, it is recommended to detach loss_items from their computational graph for memory efficiency.
When using configuration files, for example when training with train_from_recipe (or similar, when the underlying train method being called is Trainer.train_from_config(...)), register your loss class in your my_loss.py by decorating it with register_loss:
import torch.nn
from super_gradients.common.registry import register_loss


@register_loss("my_loss")
class MyLoss(torch.nn.Module):
    ...
Then, in your my_training_hyperparams.yaml, use "my_loss" in the same way as any other loss supported in SG:
...
loss: my_loss
criterion_params:
  ...
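If MyLoss defines constructor arguments, criterion_params is where their values go; for instance, assuming a hypothetical my_coefficient argument in MyLoss.__init__:

...
loss: my_loss

criterion_params:
  my_coefficient: 0.5  # hypothetical constructor argument of MyLoss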
Last, in your my_train_from_recipe_script.py file, just import the newly registered class (even though the class itself is unused, simply to trigger the registry):
from omegaconf import DictConfig
import hydra
import pkg_resources

from my_loss import MyLoss  # unused import, needed to trigger the loss registry
from super_gradients import Trainer, init_trainer


@hydra.main(config_path=pkg_resources.resource_filename("super_gradients.recipes", ""), version_base="1.2")
def main(cfg: DictConfig) -> None:
    Trainer.train_from_config(cfg)


def run():
    init_trainer()
    main()


if __name__ == "__main__":
    run()