Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

adaptive_input.py 2.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import torch
  8. from torch import nn
  9. from typing import List
  10. class AdaptiveInput(nn.Module):
  11. def __init__(
  12. self,
  13. vocab_size: int,
  14. padding_idx: int,
  15. initial_dim: int,
  16. factor: float,
  17. output_dim: int,
  18. cutoff: List[int],
  19. ):
  20. super().__init__()
  21. if vocab_size > cutoff[-1]:
  22. cutoff = cutoff + [vocab_size]
  23. else:
  24. assert vocab_size == cutoff[
  25. -1], 'cannot specify cutoff larger than vocab size'
  26. self.cutoff = cutoff
  27. self.embedding_dim = output_dim
  28. self.padding_idx = padding_idx
  29. self.embeddings = nn.ModuleList()
  30. for i in range(len(self.cutoff)):
  31. prev = self.cutoff[i - 1] if i > 0 else 0
  32. size = self.cutoff[i] - prev
  33. dim = int(initial_dim // (factor ** i))
  34. seq = nn.Sequential(
  35. nn.Embedding(size, dim, padding_idx),
  36. nn.Linear(dim, output_dim, bias=False)
  37. )
  38. self.embeddings.append(seq)
  39. def init_weights(m):
  40. if isinstance(m, nn.Embedding):
  41. nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5)
  42. nn.init.constant_(m.weight[padding_idx], 0)
  43. elif hasattr(m, 'weight'):
  44. nn.init.xavier_uniform_(m.weight)
  45. self.apply(init_weights)
  46. self.register_buffer('_float_tensor', torch.FloatTensor(1))
  47. def weights_for_band(self, band: int):
  48. return self.embeddings[band][0].weight, self.embeddings[band][1].weight
  49. def forward(self, input: torch.Tensor):
  50. result = self._float_tensor.new(input.shape + (self.embedding_dim,))
  51. for i in range(len(self.cutoff)):
  52. mask = input.lt(self.cutoff[i])
  53. if i > 0:
  54. mask.mul_(input.ge(self.cutoff[i - 1]))
  55. chunk_input = input[mask] - self.cutoff[i - 1]
  56. else:
  57. chunk_input = input[mask]
  58. if mask.any():
  59. result[mask] = self.embeddings[i](chunk_input)
  60. return result
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...