Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

simple_model.py 2.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
  1. import torch
  2. from torch import nn
  3. import torch.nn.functional as F
  4. class ConllModel(nn.Module):
  5. def __init__(self,max_seq_len,emb_dim,num_labels):
  6. super().__init__()
  7. self.num_labels = num_labels
  8. cnn_kwargs = {
  9. "in_channels":1,
  10. "out_channels":16,
  11. "kernel_size":(3,emb_dim),
  12. "padding":(1,0),
  13. }
  14. self.cnn = nn.Conv2d(**cnn_kwargs)
  15. h_out = self.compute_output_dim(
  16. h_in = max_seq_len,
  17. padding = cnn_kwargs["padding"][0],
  18. kernel_size = cnn_kwargs["kernel_size"][0],
  19. )
  20. w_out = self.compute_output_dim(
  21. h_in = emb_dim,
  22. padding = cnn_kwargs["padding"][1],
  23. kernel_size = cnn_kwargs["kernel_size"][1],
  24. )
  25. last_dim = h_out*w_out*cnn_kwargs["out_channels"]
  26. self.fc = nn.Linear(last_dim,num_labels*max_seq_len)
  27. def compute_output_dim(self,h_in,padding,kernel_size):
  28. """
  29. Given input dim, padding, kernel size
  30. compute output dim
  31. (assumes stride=1,dilation=1)
  32. """
  33. out = h_in
  34. out += 2*padding
  35. out += -kernel_size
  36. out += 1
  37. return out
  38. def forward(self,x_raw,apply_softmax=False,verbose=False):
  39. batch_len,max_seq_len,_ = x_raw.size()
  40. x = x_raw.unsqueeze(dim=1)
  41. x = self.cnn(x)
  42. x = F.relu(x)
  43. x = x.view(batch_len,-1)
  44. x = self.fc(x)
  45. if verbose is True:
  46. print(f"Batch len is {batch_len}")
  47. print(f"max_seq_len is {max_seq_len}")
  48. print(f"num labels is {self.num_labels}")
  49. print(f"Tensor shape is {x.size()}")
  50. if apply_softmax is True:
  51. y_pred = x.view(batch_len,max_seq_len,self.num_labels)
  52. y_pred = F.softmax(y_pred,dim=2)
  53. y_pred = torch.argmax(y_pred,dim=2)
  54. else:
  55. y_pred = x.view(batch_len*max_seq_len,-1)
  56. return y_pred
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...