Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

embedding.py 3.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
  1. # Copyright 2023 (authors: Feiteng Li)
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import torch
  16. import torch.nn as nn
  17. class TokenEmbedding(nn.Module):
  18. def __init__(
  19. self,
  20. dim_model: int,
  21. vocab_size: int,
  22. dropout: float = 0.0,
  23. ):
  24. super().__init__()
  25. self.vocab_size = vocab_size
  26. self.dim_model = dim_model
  27. self.dropout = torch.nn.Dropout(p=dropout)
  28. self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
  29. @property
  30. def weight(self) -> torch.Tensor:
  31. return self.word_embeddings.weight
  32. def embedding(self, index: int) -> torch.Tensor:
  33. return self.word_embeddings.weight[index : index + 1]
  34. def forward(self, x: torch.Tensor):
  35. X = self.word_embeddings(x)
  36. X = self.dropout(X)
  37. return X
  38. class SinePositionalEmbedding(nn.Module):
  39. def __init__(
  40. self,
  41. dim_model: int,
  42. dropout: float = 0.0,
  43. scale: bool = False,
  44. alpha: bool = False,
  45. ):
  46. super().__init__()
  47. self.dim_model = dim_model
  48. self.x_scale = math.sqrt(dim_model) if scale else 1.0
  49. self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
  50. self.dropout = torch.nn.Dropout(p=dropout)
  51. self.reverse = False
  52. self.pe = None
  53. self.extend_pe(torch.tensor(0.0).expand(1, 4000))
  54. def extend_pe(self, x):
  55. """Reset the positional encodings."""
  56. if self.pe is not None:
  57. if self.pe.size(1) >= x.size(1):
  58. if self.pe.dtype != x.dtype or self.pe.device != x.device:
  59. self.pe = self.pe.to(dtype=x.dtype, device=x.device)
  60. return
  61. pe = torch.zeros(x.size(1), self.dim_model)
  62. if self.reverse:
  63. position = torch.arange(
  64. x.size(1) - 1, -1, -1.0, dtype=torch.float32
  65. ).unsqueeze(1)
  66. else:
  67. position = torch.arange(
  68. 0, x.size(1), dtype=torch.float32
  69. ).unsqueeze(1)
  70. div_term = torch.exp(
  71. torch.arange(0, self.dim_model, 2, dtype=torch.float32)
  72. * -(math.log(10000.0) / self.dim_model)
  73. )
  74. pe[:, 0::2] = torch.sin(position * div_term)
  75. pe[:, 1::2] = torch.cos(position * div_term)
  76. pe = pe.unsqueeze(0)
  77. self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
  78. def forward(self, x: torch.Tensor) -> torch.Tensor:
  79. self.extend_pe(x)
  80. output = x.unsqueeze(-1) if x.ndim == 2 else x
  81. output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
  82. return self.dropout(output)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...