Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

cnn.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
  1. import yaml
  2. import torch
  3. from collections import OrderedDict
  4. from torch import nn
  5. from model.base import ModelBase
  6. with open('params.yaml', 'r') as f:
  7. PARAMS = yaml.safe_load(f)
  8. if torch.cuda.is_available():
  9. DEVICE = torch.device('cuda', PARAMS.get('gpu', 0))
  10. else:
  11. DEVICE = torch.device('cpu')
  12. class Model(ModelBase):
  13. def __init__(self, vocab_size, embed_dim, hidden_size, kernel_size, n_layers, dropout, num_classes,
  14. padding_idx, *args, **kwargs):
  15. super(Model, self).__init__()
  16. self.hidden_size = hidden_size
  17. self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
  18. layers = [
  19. ('conv1', nn.Conv1d(embed_dim, hidden_size, kernel_size)),
  20. ('drop1', nn.Dropout(dropout)),
  21. ('mp1', nn.MaxPool1d(kernel_size)),
  22. ('relu1', nn.ReLU())
  23. ]
  24. for i in range(2, n_layers+1):
  25. layers += [
  26. (f'conv{i}', nn.Conv1d(hidden_size, hidden_size, kernel_size)),
  27. (f'drop{i}', nn.Dropout(dropout)),
  28. (f'mp{i}', nn.MaxPool1d(kernel_size)),
  29. (f'relu{i}', nn.ReLU())
  30. ]
  31. self.conv = nn.Sequential(OrderedDict(layers))
  32. self.out = nn.Linear(hidden_size, 1)
  33. self.init_weights()
  34. def init_weights(self):
  35. nn.init.xavier_normal_(self.embedding.weight)
  36. for m in self.conv.modules():
  37. if isinstance(m, nn.Conv1d):
  38. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  39. nn.init.constant_(m.bias, 0)
  40. nn.init.xavier_normal_(self.out.weight)
  41. nn.init.constant_(self.out.bias, 0)
  42. def forward(self, text, text_lengths, hidden=None):
  43. # text = [L x B]
  44. emb = self.embedding(text)
  45. # emb = [L x B x D] -> [B x D x L]
  46. emb = emb.permute(1, 2, 0)
  47. x = self.conv(emb)
  48. x, _ = torch.max(x, dim=-1)
  49. x = self.out(x).sigmoid()
  50. return x
  51. def load_model(self, model_path):
  52. self.load_state_dict(torch.load(model_path))
  53. self.eval()
  54. def save_model(self, model_path):
  55. torch.save(self.state_dict(), model_path)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...