metrics.py

```python
"""
This file defines the metrics used for training.
Each metric object must be a torchmetrics.Metric instance. For more information on how to use torchmetrics.Metric objects and
implement your own metrics, see https://torchmetrics.readthedocs.io/en/latest/pages/overview.html
"""
import torchmetrics
import torch

# Note: subclassing torchmetrics.Accuracy with only top_k (no `task` argument)
# relies on the older torchmetrics API; releases from 0.11 on require `task`.


class Accuracy(torchmetrics.Accuracy):
    # Top-1 accuracy: correct when the highest-scoring class equals the target.
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step, top_k=1)

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # Convert raw scores (logits) to probabilities over the class dimension (dim 1).
        super().update(preds=preds.softmax(1), target=target)


class Top5(torchmetrics.Accuracy):
    # Top-5 accuracy: correct when the target is among the 5 highest-scoring classes.
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step, top_k=5)

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # Convert raw scores (logits) to probabilities over the class dimension (dim 1).
        super().update(preds=preds.softmax(1), target=target)
```
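`Accuracy` and `Top5` differ only in the `top_k` value passed to `torchmetrics.Accuracy`. The following is a minimal plain-PyTorch sketch of what top-k accuracy measures, assuming `preds` is an (N, C) tensor of per-class scores and `target` holds N integer class indices. The helper `topk_accuracy` is hypothetical and not part of torchmetrics. The sketch also shows why the `softmax(1)` call in `update` leaves the score unchanged: softmax preserves the ordering of scores within each row, so the same classes land in the top k.

```python
import torch


def topk_accuracy(preds: torch.Tensor, target: torch.Tensor, k: int = 1) -> float:
    """Hypothetical helper: fraction of rows whose target is among the k highest scores.

    preds:  (N, C) tensor of scores (logits or probabilities)
    target: (N,) tensor of integer class indices
    """
    # Indices of the k largest scores in each row, shape (N, k)
    topk_idx = preds.topk(k, dim=1).indices
    # A row counts as correct if its target appears anywhere in its top-k indices
    hits = (topk_idx == target.unsqueeze(1)).any(dim=1)
    return hits.float().mean().item()


preds = torch.tensor([[0.1, 2.0, 0.3],
                      [1.5, 0.2, 0.1]])  # made-up scores: 2 samples, 3 classes
target = torch.tensor([1, 2])

print(topk_accuracy(preds, target, k=1))             # 0.5
print(topk_accuracy(preds.softmax(1), target, k=1))  # 0.5 (softmax keeps the ranking)
print(topk_accuracy(preds, target, k=3))             # 1.0
```

With k equal to the number of classes every sample counts as correct, which is why top-k accuracy is only informative when k is well below C.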