1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
|
- """
- Googlenet code based on https://pytorch.org/vision/stable/_modules/torchvision/models/googlenet.html
- """
- from collections import namedtuple
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from collections import OrderedDict
- from super_gradients.training.models.sg_module import SgModule
- GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['log_', 'aux_logits2', 'aux_logits1'])
- class GoogLeNet(SgModule):
- def __init__(self, num_classes=1000, aux_logits=True, init_weights=True,
- backbone_mode=False, dropout=0.3):
- super(GoogLeNet, self).__init__()
- self.num_classes = num_classes
- self.backbone_mode = backbone_mode
- self.aux_logits = aux_logits
- self.dropout_p = dropout
- self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
- self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
- self.conv2 = BasicConv2d(64, 64, kernel_size=1)
- self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
- self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
- self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
- self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
- self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
- self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
- self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
- self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
- self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
- self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
- self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
- self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
- self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
- if aux_logits:
- self.aux1 = InceptionAux(512, num_classes)
- self.aux2 = InceptionAux(528, num_classes)
- else:
- self.aux1 = None
- self.aux2 = None
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
- if not self.backbone_mode:
- self.dropout = nn.Dropout(self.dropout_p)
- self.fc = nn.Linear(1024, num_classes)
- if init_weights:
- self._initialize_weights()
- def _initialize_weights(self):
- for m in self.modules():
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
- import scipy.stats as stats
- x = stats.truncnorm(-2, 2, scale=0.01)
- values = torch.as_tensor(x.rvs(m.weight.numel()), dtype=m.weight.dtype)
- values = values.view(m.weight.size())
- with torch.no_grad():
- m.weight.copy_(values)
- elif isinstance(m, nn.BatchNorm2d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
- def _forward(self, x):
- # N x 3 x 224 x 224
- x = self.conv1(x)
- # N x 64 x 112 x 112
- x = self.maxpool1(x)
- # N x 64 x 56 x 56
- x = self.conv2(x)
- # N x 64 x 56 x 56
- x = self.conv3(x)
- # N x 192 x 56 x 56
- x = self.maxpool2(x)
- # N x 192 x 28 x 28
- x = self.inception3a(x)
- # N x 256 x 28 x 28
- x = self.inception3b(x)
- # N x 480 x 28 x 28
- x = self.maxpool3(x)
- # N x 480 x 14 x 14
- x = self.inception4a(x)
- # N x 512 x 14 x 14
- aux1 = None
- if self.aux1 is not None and self.training:
- aux1 = self.aux1(x)
- x = self.inception4b(x)
- # N x 512 x 14 x 14
- x = self.inception4c(x)
- # N x 512 x 14 x 14
- x = self.inception4d(x)
- # N x 528 x 14 x 14
- aux2 = None
- if self.aux2 is not None and self.training:
- aux2 = self.aux2(x)
- x = self.inception4e(x)
- # N x 832 x 14 x 14
- x = self.maxpool4(x)
- # N x 832 x 7 x 7
- x = self.inception5a(x)
- # N x 832 x 7 x 7
- x = self.inception5b(x)
- # N x 1024 x 7 x 7
- x = self.avgpool(x)
- # N x 1024 x 1 x 1
- x = torch.flatten(x, 1)
- # N x 1024
- if not self.backbone_mode:
- x = self.dropout(x)
- x = self.fc(x)
- # N x num_classes
- return x, aux2, aux1
- def forward(self, x):
- x, aux1, aux2 = self._forward(x)
- if self.training and self.aux_logits:
- return GoogLeNetOutputs(x, aux2, aux1)
- else:
- return x
- def load_state_dict(self, state_dict, strict=True):
- """
- load_state_dict - Overloads the base method and calls it to load a modified dict for usage as a backbone
- :param state_dict: The state_dict to load
- :param strict: strict loading (see super() docs)
- """
- pretrained_model_weights_dict = state_dict.copy()
- if self.backbone_mode:
- # FIRST LET'S POP THE LAST TWO LAYERS - NO NEED TO LOAD THEIR VALUES SINCE THEY ARE IRRELEVANT AS A BACKBONE
- pretrained_model_weights_dict.popitem()
- pretrained_model_weights_dict.popitem()
- pretrained_backbone_weights_dict = OrderedDict()
- for layer_name, weights in pretrained_model_weights_dict.items():
- # GET THE LAYER NAME WITHOUT THE 'module.' PREFIX
- name_without_module_prefix = layer_name.split('module.')[1]
- # MAKE SURE THESE ARE NOT THE FINAL LAYERS
- pretrained_backbone_weights_dict[name_without_module_prefix] = weights
- c_temp = torch.nn.Linear(1024, self.num_classes)
- torch.nn.init.xavier_uniform(c_temp.weight)
- pretrained_backbone_weights_dict['fc.weight'] = c_temp.weight
- pretrained_backbone_weights_dict['fc.bias'] = c_temp.bias
- # RETURNING THE UNMODIFIED/MODIFIED STATE DICT DEPENDING ON THE backbone_mode VALUE
- super().load_state_dict(pretrained_backbone_weights_dict, strict)
- else:
- super().load_state_dict(pretrained_model_weights_dict, strict)
- class Inception(nn.Module):
- def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj,
- conv_block=None):
- super(Inception, self).__init__()
- if conv_block is None:
- conv_block = BasicConv2d
- self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)
- self.branch2 = nn.Sequential(
- conv_block(in_channels, ch3x3red, kernel_size=1),
- conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
- )
- self.branch3 = nn.Sequential(
- conv_block(in_channels, ch5x5red, kernel_size=1),
- conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1)
- )
- self.branch4 = nn.Sequential(
- nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
- conv_block(in_channels, pool_proj, kernel_size=1)
- )
- def _forward(self, x):
- branch1 = self.branch1(x)
- branch2 = self.branch2(x)
- branch3 = self.branch3(x)
- branch4 = self.branch4(x)
- outputs = [branch1, branch2, branch3, branch4]
- return outputs
- def forward(self, x):
- outputs = self._forward(x)
- return torch.cat(outputs, 1)
- class InceptionAux(nn.Module):
- def __init__(self, in_channels, num_classes, conv_block=None):
- super(InceptionAux, self).__init__()
- if conv_block is None:
- conv_block = BasicConv2d
- self.conv = conv_block(in_channels, 128, kernel_size=1)
- self.fc1 = nn.Linear(2048, 1024)
- self.fc2 = nn.Linear(1024, num_classes)
- def forward(self, x):
- # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
- x = F.adaptive_avg_pool2d(x, (4, 4))
- # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
- x = self.conv(x)
- # N x 128 x 4 x 4
- x = torch.flatten(x, 1)
- # N x 2048
- x = F.relu(self.fc1(x), inplace=True)
- # N x 1024
- x = F.dropout(x, 0.7, training=self.training)
- # N x 1024
- x = self.fc2(x)
- # N x 1000 (num_classes)
- return x
- class BasicConv2d(nn.Module):
- def __init__(self, in_channels, out_channels, **kwargs):
- super(BasicConv2d, self).__init__()
- self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
- self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
- self.relu = nn.ReLU()
- def forward(self, x):
- x = self.conv(x)
- x = self.bn(x)
- x = self.relu(x)
- return x
- class GoogleNetV1(GoogLeNet):
- def __init__(self, arch_params):
- super(GoogleNetV1, self).__init__(aux_logits=False, num_classes=arch_params.num_classes, dropout=arch_params.dropout)
|