Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

feature_pyramid_network.py 8.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
  1. from collections import OrderedDict
  2. from typing import Callable, Dict, List, Optional, Tuple
  3. import torch.nn.functional as F
  4. from torch import nn, Tensor
  5. from ..ops.misc import Conv2dNormActivation
  6. from ..utils import _log_api_usage_once
  7. class ExtraFPNBlock(nn.Module):
  8. """
  9. Base class for the extra block in the FPN.
  10. Args:
  11. results (List[Tensor]): the result of the FPN
  12. x (List[Tensor]): the original feature maps
  13. names (List[str]): the names for each one of the
  14. original feature maps
  15. Returns:
  16. results (List[Tensor]): the extended set of results
  17. of the FPN
  18. names (List[str]): the extended set of names for the results
  19. """
  20. def forward(
  21. self,
  22. results: List[Tensor],
  23. x: List[Tensor],
  24. names: List[str],
  25. ) -> Tuple[List[Tensor], List[str]]:
  26. pass
  27. class FeaturePyramidNetwork(nn.Module):
  28. """
  29. Module that adds a FPN from on top of a set of feature maps. This is based on
  30. `"Feature Pyramid Network for Object Detection" <https://arxiv.org/abs/1612.03144>`_.
  31. The feature maps are currently supposed to be in increasing depth
  32. order.
  33. The input to the model is expected to be an OrderedDict[Tensor], containing
  34. the feature maps on top of which the FPN will be added.
  35. Args:
  36. in_channels_list (list[int]): number of channels for each feature map that
  37. is passed to the module
  38. out_channels (int): number of channels of the FPN representation
  39. extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
  40. be performed. It is expected to take the fpn features, the original
  41. features and the names of the original features as input, and returns
  42. a new list of feature maps and their corresponding names
  43. norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
  44. Examples::
  45. >>> m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
  46. >>> # get some dummy data
  47. >>> x = OrderedDict()
  48. >>> x['feat0'] = torch.rand(1, 10, 64, 64)
  49. >>> x['feat2'] = torch.rand(1, 20, 16, 16)
  50. >>> x['feat3'] = torch.rand(1, 30, 8, 8)
  51. >>> # compute the FPN on top of x
  52. >>> output = m(x)
  53. >>> print([(k, v.shape) for k, v in output.items()])
  54. >>> # returns
  55. >>> [('feat0', torch.Size([1, 5, 64, 64])),
  56. >>> ('feat2', torch.Size([1, 5, 16, 16])),
  57. >>> ('feat3', torch.Size([1, 5, 8, 8]))]
  58. """
  59. _version = 2
  60. def __init__(
  61. self,
  62. in_channels_list: List[int],
  63. out_channels: int,
  64. extra_blocks: Optional[ExtraFPNBlock] = None,
  65. norm_layer: Optional[Callable[..., nn.Module]] = None,
  66. ):
  67. super().__init__()
  68. _log_api_usage_once(self)
  69. self.inner_blocks = nn.ModuleList()
  70. self.layer_blocks = nn.ModuleList()
  71. for in_channels in in_channels_list:
  72. if in_channels == 0:
  73. raise ValueError("in_channels=0 is currently not supported")
  74. inner_block_module = Conv2dNormActivation(
  75. in_channels, out_channels, kernel_size=1, padding=0, norm_layer=norm_layer, activation_layer=None
  76. )
  77. layer_block_module = Conv2dNormActivation(
  78. out_channels, out_channels, kernel_size=3, norm_layer=norm_layer, activation_layer=None
  79. )
  80. self.inner_blocks.append(inner_block_module)
  81. self.layer_blocks.append(layer_block_module)
  82. # initialize parameters now to avoid modifying the initialization of top_blocks
  83. for m in self.modules():
  84. if isinstance(m, nn.Conv2d):
  85. nn.init.kaiming_uniform_(m.weight, a=1)
  86. if m.bias is not None:
  87. nn.init.constant_(m.bias, 0)
  88. if extra_blocks is not None:
  89. if not isinstance(extra_blocks, ExtraFPNBlock):
  90. raise TypeError(f"extra_blocks should be of type ExtraFPNBlock not {type(extra_blocks)}")
  91. self.extra_blocks = extra_blocks
  92. def _load_from_state_dict(
  93. self,
  94. state_dict,
  95. prefix,
  96. local_metadata,
  97. strict,
  98. missing_keys,
  99. unexpected_keys,
  100. error_msgs,
  101. ):
  102. version = local_metadata.get("version", None)
  103. if version is None or version < 2:
  104. num_blocks = len(self.inner_blocks)
  105. for block in ["inner_blocks", "layer_blocks"]:
  106. for i in range(num_blocks):
  107. for type in ["weight", "bias"]:
  108. old_key = f"{prefix}{block}.{i}.{type}"
  109. new_key = f"{prefix}{block}.{i}.0.{type}"
  110. if old_key in state_dict:
  111. state_dict[new_key] = state_dict.pop(old_key)
  112. super()._load_from_state_dict(
  113. state_dict,
  114. prefix,
  115. local_metadata,
  116. strict,
  117. missing_keys,
  118. unexpected_keys,
  119. error_msgs,
  120. )
  121. def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
  122. """
  123. This is equivalent to self.inner_blocks[idx](x),
  124. but torchscript doesn't support this yet
  125. """
  126. num_blocks = len(self.inner_blocks)
  127. if idx < 0:
  128. idx += num_blocks
  129. out = x
  130. for i, module in enumerate(self.inner_blocks):
  131. if i == idx:
  132. out = module(x)
  133. return out
  134. def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
  135. """
  136. This is equivalent to self.layer_blocks[idx](x),
  137. but torchscript doesn't support this yet
  138. """
  139. num_blocks = len(self.layer_blocks)
  140. if idx < 0:
  141. idx += num_blocks
  142. out = x
  143. for i, module in enumerate(self.layer_blocks):
  144. if i == idx:
  145. out = module(x)
  146. return out
  147. def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
  148. """
  149. Computes the FPN for a set of feature maps.
  150. Args:
  151. x (OrderedDict[Tensor]): feature maps for each feature level.
  152. Returns:
  153. results (OrderedDict[Tensor]): feature maps after FPN layers.
  154. They are ordered from the highest resolution first.
  155. """
  156. # unpack OrderedDict into two lists for easier handling
  157. names = list(x.keys())
  158. x = list(x.values())
  159. last_inner = self.get_result_from_inner_blocks(x[-1], -1)
  160. results = []
  161. results.append(self.get_result_from_layer_blocks(last_inner, -1))
  162. for idx in range(len(x) - 2, -1, -1):
  163. inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
  164. feat_shape = inner_lateral.shape[-2:]
  165. inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
  166. last_inner = inner_lateral + inner_top_down
  167. results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))
  168. if self.extra_blocks is not None:
  169. results, names = self.extra_blocks(results, x, names)
  170. # make it back an OrderedDict
  171. out = OrderedDict([(k, v) for k, v in zip(names, results)])
  172. return out
  173. class LastLevelMaxPool(ExtraFPNBlock):
  174. """
  175. Applies a max_pool2d (not actual max_pool2d, we just subsample) on top of the last feature map
  176. """
  177. def forward(
  178. self,
  179. x: List[Tensor],
  180. y: List[Tensor],
  181. names: List[str],
  182. ) -> Tuple[List[Tensor], List[str]]:
  183. names.append("pool")
  184. # Use max pooling to simulate stride 2 subsampling
  185. x.append(F.max_pool2d(x[-1], kernel_size=1, stride=2, padding=0))
  186. return x, names
  187. class LastLevelP6P7(ExtraFPNBlock):
  188. """
  189. This module is used in RetinaNet to generate extra layers, P6 and P7.
  190. """
  191. def __init__(self, in_channels: int, out_channels: int):
  192. super().__init__()
  193. self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
  194. self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
  195. for module in [self.p6, self.p7]:
  196. nn.init.kaiming_uniform_(module.weight, a=1)
  197. nn.init.constant_(module.bias, 0)
  198. self.use_P5 = in_channels == out_channels
  199. def forward(
  200. self,
  201. p: List[Tensor],
  202. c: List[Tensor],
  203. names: List[str],
  204. ) -> Tuple[List[Tensor], List[str]]:
  205. p5, c5 = p[-1], c[-1]
  206. x = p5 if self.use_P5 else c5
  207. p6 = self.p6(x)
  208. p7 = self.p7(F.relu(p6))
  209. p.extend([p6, p7])
  210. names.extend(["p6", "p7"])
  211. return p, names
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...