#381 Feature/sg 000 connect to lab

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/sg-000_connect_to_lab
import torch
from torch import nn
import torch.nn.functional as F


def onehot(indexes, N=None, ignore_index=None):
    """
    Creates a one-hot representation of indexes with N possible entries.
    If N is not specified, it is inferred from the maximum index appearing.
    indexes is a long tensor of indexes.
    ignore_index will be zero in the one-hot representation.
    """
    if N is None:
        N = indexes.max() + 1
    sz = list(indexes.size())
    output = indexes.new().byte().resize_(*sz, N).zero_()
    output.scatter_(-1, indexes.unsqueeze(-1), 1)
    if ignore_index is not None and ignore_index >= 0:
        output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0)
    return output
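
A minimal sketch of what onehot produces, assuming the definitions above are in scope; the input tensor here is illustrative and not part of the PR:

# Sketch: one-hot encode a batch of three labels over N=4 classes.
idx = torch.tensor([0, 2, 1])   # long tensor of class indexes
onehot(idx, N=4)
# tensor([[1, 0, 0, 0],
#         [0, 0, 1, 0],
#         [0, 1, 0, 0]], dtype=torch.uint8)
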
def _is_long(x):
    # Unwrap .data (legacy Variable support) before checking the tensor type.
    if hasattr(x, 'data'):
        x = x.data
    return isinstance(x, torch.LongTensor) or isinstance(x, torch.cuda.LongTensor)
def cross_entropy(inputs, target, weight=None, ignore_index=-100, reduction='mean',  # noqa: C901
                  smooth_eps=None, smooth_dist=None, from_logits=True):
    """cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567"""
    smooth_eps = smooth_eps or 0

    # ordinary log-likelihood - use cross_entropy from nn
    if _is_long(target) and smooth_eps == 0:
        if from_logits:
            return F.cross_entropy(inputs, target, weight, ignore_index=ignore_index, reduction=reduction)
        else:
            return F.nll_loss(inputs, target, weight, ignore_index=ignore_index, reduction=reduction)

    if from_logits:
        # log-softmax of inputs
        lsm = F.log_softmax(inputs, dim=-1)
    else:
        lsm = inputs

    masked_indices = None
    num_classes = inputs.size(-1)

    if _is_long(target) and ignore_index >= 0:
        masked_indices = target.eq(ignore_index)

    if smooth_eps > 0 and smooth_dist is not None:
        if _is_long(target):
            target = onehot(target, num_classes).type_as(inputs)
        if smooth_dist.dim() < target.dim():
            smooth_dist = smooth_dist.unsqueeze(0)
        target.lerp_(smooth_dist, smooth_eps)

    if weight is not None:
        lsm = lsm * weight.unsqueeze(0)

    if _is_long(target):
        eps_nll = 1. - smooth_eps
        likelihood = lsm.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        loss = -(eps_nll * likelihood + smooth_eps * lsm.mean(-1))
    else:
        loss = -(target * lsm).sum(-1)

    if masked_indices is not None:
        loss.masked_fill_(masked_indices, 0)

    if reduction == 'sum':
        loss = loss.sum()
    elif reduction == 'mean':
        if masked_indices is None:
            loss = loss.mean()
        else:
            loss = loss.sum() / float(loss.size(0) - masked_indices.sum())

    return loss
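
For hard (long-tensor) targets with uniform smoothing, the loss above reduces to -(1 - eps) * log p_y - eps * mean over classes of log p_c, which is the standard label-smoothing objective with the eps mass spread uniformly over the K classes. A minimal sketch of calling it directly; the shapes are illustrative and not from the PR:

# Sketch: smoothed cross entropy on random logits (batch of 8, 10 classes).
logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss = cross_entropy(logits, labels, smooth_eps=0.1)  # scalar, reduction='mean'
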
class LabelSmoothingCrossEntropyLoss(nn.CrossEntropyLoss):
    """CrossEntropyLoss - with ability to receive a distribution as targets, and optional label smoothing"""

    def __init__(self, weight=None, ignore_index=-100, reduction='mean', smooth_eps=None, smooth_dist=None,
                 from_logits=True):
        super(LabelSmoothingCrossEntropyLoss, self).__init__(weight=weight,
                                                             ignore_index=ignore_index, reduction=reduction)
        self.smooth_eps = smooth_eps
        self.smooth_dist = smooth_dist
        self.from_logits = from_logits

    def forward(self, input, target, smooth_dist=None):
        if smooth_dist is None:
            smooth_dist = self.smooth_dist
        loss = cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index,
                             reduction=self.reduction, smooth_eps=self.smooth_eps,
                             smooth_dist=smooth_dist, from_logits=self.from_logits)
        # CHANGED TO THE CURRENT FORMAT - OUR CRITERION FUNCTIONS SHOULD ALL NOW RETURN A TUPLE OF
        # (LOSS_FOR_BACKPROP, ADDITIONAL_ITEMS), WHERE ADDITIONAL ITEMS ARE TORCH TENSORS OF SIZE
        # (N_ITEMS, ...) DETACHED FROM THEIR GRADIENTS FOR LOGGING
        return loss, loss.unsqueeze(0).detach()
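
Since forward now returns a (loss, items) tuple rather than a bare tensor, a training step has to unpack it before calling backward. A minimal sketch under that convention; the random logits stand in for a model's output and are not part of the PR:

# Sketch: unpacking the (loss, logging items) tuple in a training step.
criterion = LabelSmoothingCrossEntropyLoss(smooth_eps=0.1)
logits = torch.randn(8, 10, requires_grad=True)  # stand-in for model output
labels = torch.randint(0, 10, (8,))
loss, loss_items = criterion(logits, labels)
loss.backward()  # backprop through the first tuple element only
# loss_items has shape (1,) and is detached from the graph, so it can be
# accumulated or logged without keeping gradients alive.
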