Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#970 Update YoloNASQuickstart.md

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/SG-000_fix_readme_yolonas_snippets
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
  1. from __future__ import print_function, absolute_import
  2. import torch
  3. import torch.nn as nn
  4. from torch.nn.modules.loss import _Loss
  5. from super_gradients.common.object_names import Losses
  6. from super_gradients.common.registry.registry import register_loss
  7. from super_gradients.training.utils import convert_to_tensor
  8. @register_loss(Losses.R_SQUARED_LOSS)
  9. class RSquaredLoss(_Loss):
  10. def forward(self, output, target):
  11. # FIXME - THIS NEEDS TO BE CHANGED SUCH THAT THIS CLASS INHERETS FROM _Loss (TAKE A LOOK AT YoLoV3DetectionLoss)
  12. """Computes the R-squared for the output and target values
  13. :param output: Tensor / Numpy / List
  14. The prediction
  15. :param target: Tensor / Numpy / List
  16. The corresponding lables
  17. """
  18. # Convert to tensor
  19. output = convert_to_tensor(output)
  20. target = convert_to_tensor(target)
  21. criterion_mse = nn.MSELoss()
  22. return 1 - criterion_mse(output, target).item() / torch.var(target).item()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file