1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
|
- '''PNASNet in PyTorch.
- Paper: Progressive Neural Architecture Search
- https://github.com/kuangliu/pytorch-cifar/blob/master/models/pnasnet.py
- '''
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from super_gradients.training.models.sg_module import SgModule
- class SepConv(nn.Module):
- '''Separable Convolution.'''
- def __init__(self, in_planes, out_planes, kernel_size, stride):
- super(SepConv, self).__init__()
- self.conv1 = nn.Conv2d(in_planes, out_planes,
- kernel_size, stride,
- padding=(kernel_size - 1) // 2,
- bias=False, groups=in_planes)
- self.bn1 = nn.BatchNorm2d(out_planes)
- def forward(self, x):
- return self.bn1(self.conv1(x))
- class CellA(nn.Module):
- def __init__(self, in_planes, out_planes, stride=1):
- super(CellA, self).__init__()
- self.stride = stride
- self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
- if stride == 2:
- self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
- self.bn1 = nn.BatchNorm2d(out_planes)
- def forward(self, x):
- y1 = self.sep_conv1(x)
- y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
- if self.stride == 2:
- y2 = self.bn1(self.conv1(y2))
- return F.relu(y1 + y2)
- class CellB(nn.Module):
- def __init__(self, in_planes, out_planes, stride=1):
- super(CellB, self).__init__()
- self.stride = stride
- # Left branch
- self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
- self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride)
- # Right branch
- self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride)
- if stride == 2:
- self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
- self.bn1 = nn.BatchNorm2d(out_planes)
- # Reduce channels
- self.conv2 = nn.Conv2d(2 * out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
- self.bn2 = nn.BatchNorm2d(out_planes)
- def forward(self, x):
- # Left branch
- y1 = self.sep_conv1(x)
- y2 = self.sep_conv2(x)
- # Right branch
- y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
- if self.stride == 2:
- y3 = self.bn1(self.conv1(y3))
- y4 = self.sep_conv3(x)
- # Concat & reduce channels
- b1 = F.relu(y1 + y2)
- b2 = F.relu(y3 + y4)
- y = torch.cat([b1, b2], 1)
- return F.relu(self.bn2(self.conv2(y)))
- class PNASNet(SgModule):
- def __init__(self, cell_type, num_cells, num_planes):
- super(PNASNet, self).__init__()
- self.in_planes = num_planes
- self.cell_type = cell_type
- self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False)
- self.bn1 = nn.BatchNorm2d(num_planes)
- self.layer1 = self._make_layer(num_planes, num_cells=6)
- self.layer2 = self._downsample(num_planes * 2)
- self.layer3 = self._make_layer(num_planes * 2, num_cells=6)
- self.layer4 = self._downsample(num_planes * 4)
- self.layer5 = self._make_layer(num_planes * 4, num_cells=6)
- self.linear = nn.Linear(num_planes * 4, 10)
- def _make_layer(self, planes, num_cells):
- layers = []
- for _ in range(num_cells):
- layers.append(self.cell_type(self.in_planes, planes, stride=1))
- self.in_planes = planes
- return nn.Sequential(*layers)
- def _downsample(self, planes):
- layer = self.cell_type(self.in_planes, planes, stride=2)
- self.in_planes = planes
- return layer
- def forward(self, x):
- out = F.relu(self.bn1(self.conv1(x)))
- out = self.layer1(out)
- out = self.layer2(out)
- out = self.layer3(out)
- out = self.layer4(out)
- out = self.layer5(out)
- out = F.avg_pool2d(out, 8)
- out = self.linear(out.view(out.size(0), -1))
- return out
- def PNASNetA():
- return PNASNet(CellA, num_cells=6, num_planes=44)
- def PNASNetB():
- return PNASNet(CellB, num_cells=6, num_planes=32)
- def test():
- net = PNASNetB()
- x = torch.randn(1, 3, 32, 32)
- y = net(x)
- print(y)
- # test()
|