Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#378 Feature/sg 281 add kd notebook

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-281-add_kd_notebook
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
  1. from __future__ import print_function, absolute_import
  2. import torch
  3. import torch.nn as nn
  4. from torch.nn.modules.loss import _Loss
  5. from super_gradients.training.utils import convert_to_tensor
  6. class RSquaredLoss(_Loss):
  7. def forward(self, output, target):
  8. # FIXME - THIS NEEDS TO BE CHANGED SUCH THAT THIS CLASS INHERETS FROM _Loss (TAKE A LOOK AT YoLoV3DetectionLoss)
  9. """Computes the R-squared for the output and target values
  10. :param output: Tensor / Numpy / List
  11. The prediction
  12. :param target: Tensor / Numpy / List
  13. The corresponding lables
  14. """
  15. # Convert to tensor
  16. output = convert_to_tensor(output)
  17. target = convert_to_tensor(target)
  18. criterion_mse = nn.MSELoss()
  19. return 1 - criterion_mse(output, target).item() / torch.var(target).item()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file