1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
|
- # Copyright 2023 (authors: Feiteng Li)
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- import torch
- import torch.nn as nn
class TokenEmbedding(nn.Module):
    """Embedding table for discrete token ids, followed by dropout.

    Args:
        dim_model: Width of each embedding vector.
        vocab_size: Number of rows in the embedding table.
        dropout: Probability passed to the dropout layer that is applied
            to the looked-up embeddings.
    """

    def __init__(
        self,
        dim_model: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim_model = dim_model

        # Submodules are registered in the same order as before (dropout
        # first, then the table) so module iteration order is unchanged.
        self.dropout = nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(vocab_size, dim_model)

    @property
    def weight(self) -> torch.Tensor:
        """Return the weight matrix of the wrapped ``nn.Embedding``."""
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        """Return row ``index`` of the table as a one-row slice.

        The leading dimension is kept because the row is taken with the
        slice ``index : index + 1`` rather than by plain indexing.
        """
        table = self.word_embeddings.weight
        row = slice(index, index + 1)
        return table[row]

    def forward(self, x: torch.Tensor):
        """Look up the embeddings for the indices in ``x`` and apply dropout."""
        looked_up = self.word_embeddings(x)
        return self.dropout(looked_up)
class SinePositionalEmbedding(nn.Module):
    """Add a fixed sine/cosine positional table to an input sequence.

    The table is cached on the instance as ``self.pe`` with shape
    ``(1, length, dim_model)``. It is a plain attribute, not a registered
    buffer or parameter, so it is not part of ``state_dict``; ``extend_pe``
    moves it to the input's dtype/device on demand.

    Args:
        dim_model: Size of the last dimension of the positional table.
            Both even and odd values are supported.
        dropout: Probability passed to the dropout layer applied to the
            output of ``forward``.
        scale: When true, the input is multiplied by ``sqrt(dim_model)``
            before the positional table is added; otherwise by ``1.0``.
        alpha: When true, the scalar multiplier on the positional table
            (initialised to one) requires gradients; otherwise it is frozen.
    """

    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        # When True, positions are numbered from length - 1 down to 0.
        self.reverse = False
        self.pe = None
        # Pre-compute the table for 4000 positions using a float32 dummy.
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Ensure ``self.pe`` covers ``x.size(1)`` positions.

        If the cached table is already long enough it is only converted to
        ``x``'s dtype and device (when they differ). Otherwise the table is
        rebuilt from scratch for ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 gives the required number of
                positions and whose dtype/device the table should match.
        """
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        length = x.size(1)
        pe = torch.zeros(length, self.dim_model)
        if self.reverse:
            position = torch.arange(
                length - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(
                0, length, dtype=torch.float32
            ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd dim_model there is one fewer odd-indexed column than
        # even-indexed ones, so trim the angles to fit; for even dim_model
        # the slice is a no-op. Without this, odd sizes raise a shape error.
        pe[:, 1::2] = torch.cos(angles[:, : self.dim_model // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x * x_scale + alpha * pe[:, :x.size(1)])``.

        A 2-D input gets a trailing singleton dimension first, so it
        broadcasts against the ``(1, length, dim_model)`` table.
        """
        self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)
|