att.py 2.0 KB

import torch
from torch import nn
from torch.nn import functional as F


class Attention(nn.Module):
    # Luong-style attention; hidden_size is the size of the target (decoder)
    # and encoder hidden states.
    def __init__(self, hidden_size, method='concat'):
        super().__init__()
        self.method = method
        if self.method not in ('dot', 'general', 'concat'):
            raise NotImplementedError(f"Unsupported attention method: {method}")
        if self.method == 'general':
            self.attn = nn.Linear(hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(2 * hidden_size, hidden_size)
            self.v = nn.Linear(hidden_size, 1, bias=False)
        self.init_weights()

    def init_weights(self):
        if hasattr(self, 'attn'):
            nn.init.xavier_normal_(self.attn.weight)
            nn.init.constant_(self.attn.bias, 0)
        if hasattr(self, 'v'):
            nn.init.xavier_normal_(self.v.weight)

    def dot_score(self, hidden, encoder_output):
        # hidden = [batch_size x hidden_size]
        # encoder_output = [length x batch_size x hidden_size]
        # broadcast the elementwise product over length and sum over the hidden dim
        return torch.sum(hidden * encoder_output, dim=-1)

    def general_score(self, hidden, encoder_output):
        # like dot_score, but the encoder outputs are first projected by a linear layer
        attn = self.attn(encoder_output)
        return torch.sum(hidden * attn, dim=-1)

    def concat_score(self, hidden, encoder_output):
        # repeat the decoder hidden state along the length dimension, concatenate it
        # with the encoder outputs, and score with a small feed-forward network
        hidden_reshape = torch.unsqueeze(hidden, dim=0).repeat(encoder_output.size(0), 1, 1)
        attn = self.attn(torch.cat([hidden_reshape, encoder_output], dim=-1)).tanh()
        return self.v(attn).squeeze(dim=-1)

    def forward(self, hidden, encoder_output):
        # encoder_output = [length x batch_size x hidden_size]
        # hidden = [batch_size x hidden_size]
        if self.method == 'dot':
            attn_scores = self.dot_score(hidden, encoder_output)
        elif self.method == 'general':
            attn_scores = self.general_score(hidden, encoder_output)
        else:  # 'concat'
            attn_scores = self.concat_score(hidden, encoder_output)
        # [length x batch_size] -> [batch_size x length]
        attn_scores = attn_scores.t()
        # return normalized attention weights of shape [batch_size x 1 x length]
        return F.softmax(attn_scores, dim=-1).unsqueeze(1)
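
For reference, here is a minimal usage sketch of the module above. The tensor shapes follow the comments in forward (encoder outputs of shape [length x batch_size x hidden_size], decoder hidden state of shape [batch_size x hidden_size]); the concrete sizes and the context-vector computation at the end are illustrative assumptions, not part of att.py.

# Usage sketch (assumed shapes and sizes, for illustration only)
import torch

hidden_size, batch_size, src_length = 256, 4, 10
attention = Attention(hidden_size, method='concat')

# decoder hidden state for the current step: [batch_size x hidden_size]
hidden = torch.randn(batch_size, hidden_size)
# encoder outputs for the whole source sequence: [length x batch_size x hidden_size]
encoder_output = torch.randn(src_length, batch_size, hidden_size)

# attention weights: [batch_size x 1 x length]
attn_weights = attention(hidden, encoder_output)
# context vector as a weighted sum of encoder outputs: [batch_size x 1 x hidden_size]
context = attn_weights.bmm(encoder_output.transpose(0, 1))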