Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

epa_seq2seq_model.py 4.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
  1. """
  2. EPA Sequence-to-Sequence Model
  3. see https://sladewinter.medium.com/video-frame-prediction-using-convlstm-network-in-pytorch-b5210a6ce582
  4. """
  5. import torch
  6. import torch.nn as nn
  7. # Original ConvLSTM cell as proposed by Shi et al.
  class ConvLSTMCell(nn.Module):
      """Single ConvLSTM cell with peephole connections (Shi et al.).

      One convolution over the channel-wise concatenation of the input and
      the previous hidden state yields the pre-activations of all four
      gates. The cell state additionally feeds the input, forget and output
      gates through element-wise (Hadamard) peephole weights.

      Args:
          in_channels: Channel count of the input ``X``.
          out_channels: Channel count of the hidden and cell states.
          kernel_size: Kernel size forwarded to ``nn.Conv2d``.
          padding: Padding forwarded to ``nn.Conv2d``.
          frame_size: Iterable of trailing dimensions for the peephole
              weights, which get shape ``(out_channels, *frame_size)``.
              Presumably ``(height, width)`` of the states -- it must
              broadcast against the cell state.
          activation: ``"tanh"`` or ``"relu"``.

      Raises:
          ValueError: If ``activation`` is not a supported name.
      """

      def __init__(self, in_channels, out_channels,
                   kernel_size, padding, frame_size, activation='tanh'):
          super(ConvLSTMCell, self).__init__()

          if activation == "tanh":
              self.activation = torch.tanh
          elif activation == "relu":
              self.activation = torch.relu
          else:
              # Fail at construction time instead of leaving the attribute
              # unset and crashing later inside forward().
              raise ValueError(
                  f"Unsupported activation {activation!r}; "
                  "expected 'tanh' or 'relu'"
              )

          # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
          # A single convolution produces all four gate pre-activations.
          self.conv = nn.Conv2d(
              in_channels=in_channels + out_channels,
              out_channels=4 * out_channels,
              kernel_size=kernel_size,
              padding=padding)

          # Peephole weights for the Hadamard products. They are explicitly
          # zero-initialised: torch.Tensor(...) only allocates memory, so the
          # parameters could otherwise start as arbitrary values (NaN/inf).
          self.W_ci = nn.Parameter(torch.zeros(out_channels, *frame_size))
          self.W_co = nn.Parameter(torch.zeros(out_channels, *frame_size))
          self.W_cf = nn.Parameter(torch.zeros(out_channels, *frame_size))

      def forward(self, X, H_prev, C_prev):
          """Advance the cell by one time step.

          Args:
              X: Input for this step; concatenated with ``H_prev`` on dim 1.
              H_prev: Hidden state from the previous step.
              C_prev: Cell state from the previous step.

          Returns:
              Tuple ``(H, C)`` with the new hidden state and cell state.
          """
          # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
          conv_output = self.conv(torch.cat([X, H_prev], dim=1))

          # Split the stacked channels into the four gate pre-activations.
          i_conv, f_conv, C_conv, o_conv = torch.chunk(
              conv_output, chunks=4, dim=1)

          input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev)
          forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev)

          # Current cell state
          C = forget_gate * C_prev + input_gate * self.activation(C_conv)

          # The output gate peeks at the *updated* cell state.
          output_gate = torch.sigmoid(o_conv + self.W_co * C)

          # Current hidden state
          H = output_gate * self.activation(C)

          return H, C
  class ConvLSTM(nn.Module):
      """Unroll a ``ConvLSTMCell`` over the time axis of a frame sequence.

      Args:
          in_channels: Channel count of each input frame.
          out_channels: Channel count of the hidden/cell state and output.
          kernel_size: Kernel size forwarded to the cell.
          padding: Padding forwarded to the cell.
          frame_size: Frame size forwarded to the cell.
          activation: Activation name forwarded to the cell.
          device: Stored on ``self.device`` for backward compatibility.
              ``forward`` allocates its state on the input's device, so
              this value does not have to match where the data lives.
      """

      def __init__(self, in_channels, out_channels,
                   kernel_size, padding, frame_size,
                   activation='tanh', device='cuda'):
          super(ConvLSTM, self).__init__()

          self.device = device
          self.out_channels = out_channels

          # We will unroll this over time steps
          self.convLSTMcell = ConvLSTMCell(
              in_channels, out_channels,
              kernel_size, padding, frame_size, activation
          )

      def forward(self, X):
          """Run the cell over every time step of ``X``.

          Args:
              X: Frame sequence of shape
                  ``(batch_size, num_channels, seq_len, height, width)``.

          Returns:
              Tensor of shape
              ``(batch_size, out_channels, seq_len, height, width)`` holding
              the hidden state of every time step.
          """
          batch_size, _, seq_len, height, width = X.size()

          # Allocate from X so state and output follow the input's device
          # and dtype instead of a device fixed at construction time (the
          # 'cuda' default would fail on CPU inputs or CPU-only hosts).
          output = X.new_zeros(batch_size, self.out_channels, seq_len,
                               height, width)

          # Initial hidden state
          H = X.new_zeros(batch_size, self.out_channels, height, width)

          # Initial cell state
          C = X.new_zeros(batch_size, self.out_channels, height, width)

          # Unroll over time steps
          for time_step in range(seq_len):
              H, C = self.convLSTMcell(X[:, :, time_step], H, C)
              output[:, :, time_step] = H

          return output
  class EpaSeq2Seq(nn.Module):
      """Stack of ConvLSTM + BatchNorm3d layers followed by a per-frame conv.

      Args:
          in_channels: Channel count of the input frames.
          out_channels: Channel count of the predicted frames.
          frame_size: Frame size forwarded to every ConvLSTM layer.
          num_kernels: Channel count used by every ConvLSTM layer.
          kernel_size: Kernel size shared by all convolutions.
          padding: Padding shared by all convolutions.
          num_layers: Number of ConvLSTM layers; the first one is always
              created.
          activation: Activation name forwarded to every ConvLSTM layer.
          device: Device argument forwarded to every ConvLSTM layer.
      """

      def __init__(self, in_channels, out_channels, frame_size,
                   num_kernels=64, kernel_size=(3, 3), padding=(1, 1),
                   num_layers=2, activation='tanh', device='cuda'):
          super().__init__()

          # Only the first layer consumes the raw input channels; every
          # later layer consumes the previous layer's num_kernels channels.
          layer_inputs = [in_channels] + [num_kernels] * (num_layers - 1)

          stack = nn.Sequential()
          for index, layer_in in enumerate(layer_inputs, start=1):
              recurrent = ConvLSTM(
                  in_channels=layer_in, out_channels=num_kernels,
                  kernel_size=kernel_size, padding=padding,
                  frame_size=frame_size, activation=activation,
                  device=device)
              stack.add_module(f"convlstm{index}", recurrent)
              stack.add_module(
                  f"batchnorm{index}",
                  nn.BatchNorm3d(num_features=num_kernels))
          self.sequential = stack

          # Convolution that turns the features of one time step into an
          # output frame.
          self.conv = nn.Conv2d(
              in_channels=num_kernels, out_channels=out_channels,
              kernel_size=kernel_size, padding=padding)

      def forward(self, X):
          """Run the recurrent stack, then apply the conv to each time step.

          Args:
              X: Input passed to the layer stack; the time axis is dim 2.

          Returns:
              The per-step conv outputs stacked back along dim 2.
          """
          features = self.sequential(X)

          # Slice along the time axis, convolve each slice, re-stack.
          frames = [self.conv(frame) for frame in features.unbind(dim=2)]
          return torch.stack(frames, dim=2)
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...