# functions.py

import math
import warnings
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import constant_, xavier_uniform_
from torch.nn.parameter import Parameter

Tensor = torch.Tensor

def positional_encoding(X, num_features, dropout_p=0.0, max_len=512) -> Tensor:
    r'''
    Add sinusoidal positional encoding to the input.
    Args:
        - num_features: dimensionality of the input
        - dropout_p: dropout probability; dropout is applied when it is non-zero
        - max_len: maximum sequence length, default 512
    Shape:
        - Input: [batch_size, seq_length, num_features]
        - Output: [batch_size, seq_length, num_features]
    Example:
        >>> X = torch.randn((2, 4, 10))
        >>> X = positional_encoding(X, 10)
        >>> print(X.shape)
        torch.Size([2, 4, 10])
    '''
    dropout = nn.Dropout(dropout_p)
    P = torch.zeros((1, max_len, num_features))
    X_ = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
        10000,
        torch.arange(0, num_features, 2, dtype=torch.float32) / num_features)
    # even feature indices get sin, odd indices get cos
    P[:, :, 0::2] = torch.sin(X_)
    P[:, :, 1::2] = torch.cos(X_)
    X = X + P[:, :X.shape[1], :].to(X.device)
    return dropout(X)
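
# Illustrative check (not part of the original file): with dropout_p=0 the function is
# deterministic and the added term depends only on the position, not on the input values.
# a = torch.zeros((2, 4, 10))
# b = torch.ones((2, 4, 10))
# pe_a = positional_encoding(a, 10, dropout_p=0.0)
# pe_b = positional_encoding(b, 10, dropout_p=0.0)
# print(torch.allclose(pe_b - b, pe_a - a))  # True: both inputs receive the same offsets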

def _in_projection_packed(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w: Tensor,
    b: Optional[Tensor] = None,
) -> List[Tensor]:
    r"""
    Apply the input projection using one packed weight matrix.
    Args:
        q, k, v: for self-attention all three are src; for a seq2seq model, k and v are
            the same tensor. In every case their last dimension (num_features, also
            called embed_dim) must match.
        w: the projection weight, with the q, k, v weights stacked in that order.
        b: the projection bias, with the q, k, v biases stacked in that order.
    Shape:
        Input:
            - q: `(..., E)`, where E is the embedding dimension (E keeps this meaning below).
            - k: `(..., E)`
            - v: `(..., E)`
            - w: `(E * 3, E)`
            - b: `E * 3`
        Output:
            - the list `[q', k', v']`; each projected tensor keeps the shape of its input.
    """
    E = q.size(-1)
    # For self-attention, q = k = v = src, so all three refer to the same tensor:
    # both `k is v` and `q is k` are True.
    # For seq2seq, k = v, so only `k is v` is True.
    if k is v:
        if q is k:
            # self-attention
            return nn.functional.linear(q, w, b).chunk(3, dim=-1)
        else:
            # encoder-decoder (seq2seq) attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            return (nn.functional.linear(q, w_q, b_q),) + nn.functional.linear(k, w_kv, b_kv).chunk(2, dim=-1)
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return nn.functional.linear(q, w_q, b_q), nn.functional.linear(k, w_k, b_k), nn.functional.linear(v, w_v, b_v)
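
# Illustrative check (not part of the original file): the packed projection is equivalent
# to three separate linear layers whose weights are the row-wise chunks of the packed matrix.
# E = 8
# w = torch.randn(3 * E, E)
# b = torch.randn(3 * E)
# x = torch.randn(2, 5, E)
# q_, k_, v_ = _in_projection_packed(x, x, x, w, b)    # self-attention path
# w_q, w_k, w_v = w.chunk(3)
# b_q, b_k, b_v = b.chunk(3)
# print(torch.allclose(q_, F.linear(x, w_q, b_q)))     # True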

def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    r'''
    Compute scaled dot-product attention over query, key and value, apply the attention
    mask if one is given, and apply dropout with probability dropout_p.
    Args:
        - q: `(B, Nt, E)`, where B is the batch size, Nt the target sequence length and
          E the embedding dimension.
        - k: `(B, Ns, E)`, where Ns is the source sequence length.
        - v: `(B, Ns, E)`, same shape as k.
        - attn_mask: either a 3D tensor of shape `(B, Nt, Ns)` or a 2D tensor of shape `(Nt, Ns)`.
        - Output: attention values of shape `(B, Nt, E)` (same as q) and attention
          weights of shape `(B, Nt, Ns)`.
    Example:
        >>> q = torch.randn((2, 3, 6))
        >>> k = torch.randn((2, 4, 6))
        >>> v = torch.randn((2, 4, 6))
        >>> out = _scaled_dot_product_attention(q, k, v)
        >>> out[0].shape, out[1].shape
        (torch.Size([2, 3, 6]), torch.Size([2, 3, 4]))
    '''
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    attn = torch.bmm(q, k.transpose(-2, -1))
    if attn_mask is not None:
        attn += attn_mask
    # attn: each target position attends over all source positions
    attn = nn.functional.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = nn.functional.dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)
    return output, attn
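
# Illustrative check (not part of the original file): without dropout, the returned
# attention weights are a softmax over the source dimension, so each row sums to 1.
# q = torch.randn((2, 3, 6))
# k = torch.randn((2, 4, 6))
# v = torch.randn((2, 4, 6))
# out, attn = _scaled_dot_product_attention(q, k, v)
# print(torch.allclose(attn.sum(dim=-1), torch.ones(2, 3)))  # True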

def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Optional[Tensor],
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r'''
    Shape:
        Input:
            - query: `(L, N, E)`
            - key: `(S, N, E)`
            - value: `(S, N, E)`
            - key_padding_mask: `(N, S)`
            - attn_mask: `(L, S)` or `(N * num_heads, L, S)`
        Output:
            - attn_output: `(L, N, E)`
            - attn_output_weights: `(N, L, S)`
    '''
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    head_dim = embed_dim // num_heads
    # Note: the separate q/k/v projection weights (use_separate_proj_weight=True) are
    # accepted for API compatibility but not used by this simplified implementation;
    # the packed in_proj_weight is always expected.
    q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    if attn_mask is not None:
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)
        else:
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)
    # Reshape q, k, v so the batch dimension comes first, as the batched dot-product
    # attention expects, and fold the heads into the batch dimension.
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (bsz, src_len), \
            f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
            expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
        if attn_mask is None:
            attn_mask = key_padding_mask
        elif attn_mask.dtype == torch.bool:
            attn_mask = attn_mask.logical_or(key_padding_mask)
        else:
            attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
    # If attn_mask is boolean, convert it to an additive float mask.
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
        new_attn_mask.masked_fill_(attn_mask, float("-inf"))
        attn_mask = new_attn_mask
    # Dropout is only applied during training.
    if not training:
        dropout_p = 0.0
    attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
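
# Illustrative usage (not part of the original file): shapes follow the docstring above,
# with query (L, N, E), key/value (S, N, E) and a packed (3E, E) projection weight.
# L, S, N, E, H = 4, 6, 2, 16, 4
# query = torch.randn(L, N, E)
# key = torch.randn(S, N, E)
# value = torch.randn(S, N, E)
# in_w = torch.randn(3 * E, E)
# out_w = torch.randn(E, E)
# out, weights = multi_head_attention_forward(query, key, value, H, in_w, None, 0.0, out_w, None)
# print(out.shape, weights.shape)  # torch.Size([4, 2, 16]) torch.Size([2, 4, 6])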

class MultiheadAttention(nn.Module):
    r'''
    Args:
        embed_dim: embedding dimension
        num_heads: number of parallel attention heads
        batch_first: if `True`, inputs are (batch, seq, feature); if `False`, (seq, batch, feature)
    Example:
        >>> multihead_attn = MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    '''
    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, kdim=None, vdim=None,
                 batch_first=False) -> None:
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim)))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim)))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim)))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim)))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)
        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        if self.batch_first:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            attn_output, attn_output_weights = multi_head_attention_forward(
                query, key, value, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
        if self.batch_first:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights

# src = torch.randn((2, 4, 100))
# src = positional_encoding(src, 100, 0.1)
# print(src.shape)
# multihead_attn = MultiheadAttention(100, 4, 0.1)
# attn_output, attn_output_weights = multihead_attn(src, src, src)
# print(attn_output.shape, attn_output_weights.shape)
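
# Illustrative check (not part of the original file): a boolean attn_mask (True = masked)
# and the equivalent additive float mask (-inf = masked) give the same attention output.
# mha = MultiheadAttention(16, 4)
# x = torch.randn(5, 2, 16)
# bool_mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)
# float_mask = torch.zeros(5, 5).masked_fill(bool_mask, float("-inf"))
# out_b, _ = mha(x, x, x, attn_mask=bool_mask)
# out_f, _ = mha(x, x, x, attn_mask=float_mask)
# print(torch.allclose(out_b, out_f))  # True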

class TransformerEncoderLayer(nn.Module):
    r'''
    Args:
        d_model: embedding dimension
        nhead: number of parallel heads in the multi-head attention
        dim_feedforward: number of neurons in the feed-forward layer, i.e. its hidden dimension (Default = 2048)
        dropout: dropout probability
        activation: activation function between the two linear layers, relu or gelu (Default: relu)
        layer_norm_eps: small constant in layer normalization to avoid division by zero
        batch_first: if `True`, inputs are (batch, seq, feature); if `False`, (seq, batch, feature)
    Example:
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.randn((32, 10, 512))
        >>> out = encoder_layer(src)
    '''
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu,
                 layer_norm_eps=1e-5, batch_first=False) -> None:
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.activation = activation
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        # self-attention block with residual connection and layer norm
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # feed-forward block with residual connection and layer norm
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

# encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
# src = torch.randn((32, 10, 512))
# out = encoder_layer(src)
# print(out.shape)
# torch.Size([32, 10, 512])
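
# Illustrative usage of src_key_padding_mask (not part of the original file): True marks
# padded positions that attention should ignore.
# encoder_layer = TransformerEncoderLayer(d_model=32, nhead=4)
# src = torch.randn((6, 2, 32))                      # (seq, batch, feature)
# pad_mask = torch.zeros(2, 6, dtype=torch.bool)     # (batch, seq)
# pad_mask[:, -2:] = True                            # last two positions are padding
# out = encoder_layer(src, src_key_padding_mask=pad_mask)
# print(out.shape)  # torch.Size([6, 2, 32])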

class TransformerDecoderLayer(nn.Module):
    r'''
    Args:
        d_model: embedding dimension (required)
        nhead: number of parallel heads in the multi-head attention (required)
        dim_feedforward: number of neurons in the feed-forward layer, i.e. its hidden dimension (Default = 2048)
        dropout: dropout probability (Default = 0.1)
        activation: activation function between the two linear layers, relu or gelu (Default: relu)
        layer_norm_eps: small constant in layer normalization to avoid division by zero (Default = 1e-5)
        batch_first: if `True`, inputs are (batch, seq, feature); if `False`, (seq, batch, feature) (Default: False)
    Example:
        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.randn((10, 32, 512))
        >>> tgt = torch.randn((20, 32, 512))
        >>> out = decoder_layer(tgt, memory)
    '''
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu,
                 layer_norm_eps=1e-5, batch_first=False) -> None:
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r'''
        Args:
            tgt: target-language sequence (required)
            memory: output of the last encoder layer (required)
            tgt_mask: attention mask for the target sequence (optional)
            memory_mask: attention mask for the encoder output (optional)
            tgt_key_padding_mask: padding mask for the target sequence (optional)
            memory_key_padding_mask: padding mask for the encoder output (optional)
        '''
        # masked self-attention over the target sequence
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # cross-attention over the encoder memory
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # feed-forward block
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

# decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
# memory = torch.randn((10, 32, 512))
# tgt = torch.randn((20, 32, 512))
# out = decoder_layer(tgt, memory)
# print(out.shape)
# torch.Size([20, 32, 512])

class TransformerEncoder(nn.Module):
    r'''
    Args:
        encoder_layer: an instance of TransformerEncoderLayer
        num_layers: number of times the encoder layer is applied
        norm: optional output normalization
    Example:
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.randn((10, 32, 512))
        >>> out = transformer_encoder(src)
    '''
    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        # Note: the same layer instance is reused num_layers times, so all "layers"
        # share weights (unlike nn.TransformerEncoder, which deep-copies the layer).
        self.layer = encoder_layer
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        # add the positional encoding, then apply the (shared) encoder layer num_layers times
        output = positional_encoding(src, src.shape[-1])
        for _ in range(self.num_layers):
            output = self.layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        if self.norm is not None:
            output = self.norm(output)
        return output

# encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
# transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
# src = torch.randn((10, 32, 512))
# out = transformer_encoder(src)
# print(out.shape)

class TransformerDecoder(nn.Module):
    r'''
    Args:
        decoder_layer: an instance of TransformerDecoderLayer (required)
        num_layers: number of times the decoder layer is applied (required)
        norm: optional output normalization
    Example:
        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    '''
    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        # As in TransformerEncoder, the same layer instance is reused, so weights are shared.
        self.layer = decoder_layer
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        output = tgt
        for _ in range(self.num_layers):
            output = self.layer(output, memory, tgt_mask=tgt_mask,
                                memory_mask=memory_mask,
                                tgt_key_padding_mask=tgt_key_padding_mask,
                                memory_key_padding_mask=memory_key_padding_mask)
        if self.norm is not None:
            output = self.norm(output)
        return output

# decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
# transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
# memory = torch.rand(10, 32, 512)
# tgt = torch.rand(20, 32, 512)
# out = transformer_decoder(tgt, memory)
# print(out.shape)
# torch.Size([20, 32, 512])

class Transformer(nn.Module):
    r'''
    Args:
        d_model: embedding dimension (required) (Default = 512)
        nhead: number of parallel heads in the multi-head attention (required) (Default = 8)
        num_encoder_layers: number of encoder layers (Default = 6)
        num_decoder_layers: number of decoder layers (Default = 6)
        dim_feedforward: number of neurons in the feed-forward layer, i.e. its hidden dimension (Default = 2048)
        dropout: dropout probability (Default = 0.1)
        activation: activation function between the two linear layers, relu or gelu (Default: relu)
        custom_encoder: custom encoder (Default = None)
        custom_decoder: custom decoder (Default = None)
        layer_norm_eps: small constant in layer normalization to avoid division by zero (Default = 1e-5)
        batch_first: if `True`, inputs are (batch, seq, feature); if `False`, (seq, batch, feature) (Default: False)
    Example:
        >>> transformer_model = Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)
    '''
    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation=F.relu, custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False) -> None:
        super(Transformer, self).__init__()
        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first)
            encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, batch_first)
            decoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead
        self.batch_first = batch_first

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r'''
        Args:
            src: source-language sequence, fed to the encoder (required)
            tgt: target-language sequence, fed to the decoder (required)
            src_mask: (optional)
            tgt_mask: (optional)
            memory_mask: (optional)
            src_key_padding_mask: (optional)
            tgt_key_padding_mask: (optional)
            memory_key_padding_mask: (optional)
        Shape:
            - src: `(S, N, E)`, or `(N, S, E)` if batch_first.
            - tgt: `(T, N, E)`, or `(N, T, E)` if batch_first.
            - src_mask: `(S, S)`.
            - tgt_mask: `(T, T)`.
            - memory_mask: `(T, S)`.
            - src_key_padding_mask: `(N, S)`.
            - tgt_key_padding_mask: `(N, T)`.
            - memory_key_padding_mask: `(N, S)`.
            [src/tgt/memory]_mask hides certain positions from attention: when decoding,
            for example, a position may only attend to itself and earlier positions, never
            later ones. For a ByteTensor, non-zero positions are excluded from attention;
            for a BoolTensor, positions that are True are excluded; for a float tensor,
            the values are added directly to attn_weights.
            [src/tgt/memory]_key_padding_mask excludes certain key positions from the
            attention computation; the three tensor types behave as above.
            - output: `(T, N, E)`, or `(N, T, E)` if batch_first.
        Notes:
            The last dimension of src and tgt must equal d_model, and their batch
            dimensions must match.
        Example:
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        '''
        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output

    def generate_square_subsequent_mask(self, sz: int) -> Tensor:
        r'''Generate a causal mask for the sequence: masked positions get `-inf`, visible positions get `0`.'''
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _reset_parameters(self):
        r'''Initialize parameters with the Xavier uniform distribution.'''
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

transformer_model = Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
out = transformer_model(src, tgt)
print(out.shape)
# torch.Size([20, 32, 512])
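
# Illustrative usage of the causal mask (not part of the original file): tgt_mask keeps
# each target position from attending to later positions during decoding.
# tgt_mask = transformer_model.generate_square_subsequent_mask(tgt.size(0))
# print(tgt_mask[:3, :3])   # entries above the diagonal are -inf, the rest are 0
# out_masked = transformer_model(src, tgt, tgt_mask=tgt_mask)
# print(out_masked.shape)   # torch.Size([20, 32, 512])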