loss.py 2.3 KB

"""
The loss must be of the torch.nn.modules.loss._Loss class.
For commonly used losses, import from deci.core.ADNN.losses
- IMPORTANT: forward(...) should return (loss, loss_items), where loss is the tensor used for backprop (i.e. what your
  original loss function returns) and loss_items is a tensor of shape (n_items) holding values computed during the
  forward pass that we wish to log over the entire epoch. For example, the loss itself should always be logged.
  Another example is a loss computed as the sum of several components we would like to log; those components go in
  as entries in loss_items.
- When training, set the "loss_logging_items_names" parameter in train_params to a list of strings of length
  n_items, whose ith element is the name of the ith entry in loss_items. Each item will then be logged, rendered on
  tensorboard and "watched" (i.e. model checkpoints are saved according to it).
- Since the running logs store loss_items in internal state, it is recommended to detach loss_items from the
  computational graph for memory efficiency.
"""
import torch.nn as nn
from super_gradients.training.losses.label_smoothing_cross_entropy_loss import cross_entropy


class LabelSmoothingCrossEntropyLoss(nn.CrossEntropyLoss):
    """
    LabelSmoothingCrossEntropyLoss - POC loss class, uses SuperGradients' cross entropy, which supports distributions as targets.
    """
    def __init__(self, weight=None, ignore_index=-100, reduction='mean', smooth_eps=None, smooth_dist=None,
                 from_logits=True):
        super(LabelSmoothingCrossEntropyLoss, self).__init__(weight=weight,
                                                             ignore_index=ignore_index, reduction=reduction)
        self.smooth_eps = smooth_eps
        self.smooth_dist = smooth_dist
        self.from_logits = from_logits

    def forward(self, input, target, smooth_dist=None):
        if smooth_dist is None:
            smooth_dist = self.smooth_dist
        loss = cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index,
                             reduction=self.reduction, smooth_eps=self.smooth_eps,
                             smooth_dist=smooth_dist, from_logits=self.from_logits)
        # Detach so the running logs do not retain the computational graph
        loss_items = loss.detach().unsqueeze(0)
        return loss, loss_items
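The docstring above describes losses whose value is the sum of several components, each logged separately through loss_items. As a minimal sketch of that convention (the class name, component weights, and logging names here are illustrative, not part of the library), a composite loss could look like:

```python
import torch
import torch.nn as nn


class CompositeDemoLoss(nn.Module):
    """Hypothetical loss combining an MSE term and an L1 term.

    Follows the (loss, loss_items) convention: the total loss is
    returned for backprop, and loss_items carries the total plus
    each component, detached, for epoch-level logging.
    """

    def __init__(self, l1_weight=0.5):
        super().__init__()
        self.l1_weight = l1_weight
        self.mse = nn.MSELoss()
        self.l1 = nn.L1Loss()

    def forward(self, input, target):
        mse_loss = self.mse(input, target)
        l1_loss = self.l1(input, target)
        loss = mse_loss + self.l1_weight * l1_loss
        # Detach logged items so the running logs do not hold the graph
        loss_items = torch.stack([loss, mse_loss, l1_loss]).detach()
        return loss, loss_items
```

With this loss, train_params would presumably set loss_logging_items_names to a matching list such as ["loss", "mse", "l1"], so each entry of loss_items is named, logged, and available for checkpoint watching.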