1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
|
- <!DOCTYPE html>
- <html class="writer-html5" lang="en" >
- <head>
- <meta charset="utf-8" />
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
- <title>super_gradients.training.losses.ssd_loss — SuperGradients 3.0.3 documentation</title>
- <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
- <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
- <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
- <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
- <!--[if lt IE 9]>
- <script src="../../../../_static/js/html5shiv.min.js"></script>
- <![endif]-->
-
- <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
- <script src="../../../../_static/jquery.js"></script>
- <script src="../../../../_static/underscore.js"></script>
- <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
- <script src="../../../../_static/doctools.js"></script>
- <script src="../../../../_static/sphinx_highlight.js"></script>
- <script src="../../../../_static/js/theme.js"></script>
- <link rel="index" title="Index" href="../../../../genindex.html" />
- <link rel="search" title="Search" href="../../../../search.html" />
- </head>
- <body class="wy-body-for-nav">
- <div class="wy-grid-for-nav">
- <nav data-toggle="wy-nav-shift" class="wy-nav-side">
- <div class="wy-side-scroll">
- <div class="wy-side-nav-search" >
- <a href="../../../../index.html" class="icon icon-home"> SuperGradients
- </a>
- <div role="search">
- <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
- <input type="text" name="q" placeholder="Search docs" />
- <input type="hidden" name="check_keywords" value="yes" />
- <input type="hidden" name="area" value="default" />
- </form>
- </div>
- </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
- <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
- <ul>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
- </ul>
- <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
- <ul>
- <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
- </ul>
- </div>
- </div>
- </nav>
- <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
- <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
- <a href="../../../../index.html">SuperGradients</a>
- </nav>
- <div class="wy-nav-content">
- <div class="rst-content">
- <div role="navigation" aria-label="Page navigation">
- <ul class="wy-breadcrumbs">
- <li><a href="../../../../index.html" class="icon icon-home"></a> »</li>
- <li><a href="../../../index.html">Module code</a> »</li>
- <li>super_gradients.training.losses.ssd_loss</li>
- <li class="wy-breadcrumbs-aside">
- </li>
- </ul>
- <hr/>
- </div>
- <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
- <div itemprop="articleBody">
-
- <h1>Source code for super_gradients.training.losses.ssd_loss</h1><div class="highlight"><pre>
- <span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span>
- <span class="kn">import</span> <span class="nn">torch</span>
- <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
- <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
- <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">calculate_bbox_iou_matrix</span>
- <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ssd_utils</span> <span class="kn">import</span> <span class="n">DefaultBoxes</span>
- <span class="k">class</span> <span class="nc">HardMiningCrossEntropyLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> L_cls = [CE of all positives] + [CE of the hardest backgrounds]</span>
- <span class="sd"> where the second term is built from [neg_pos_ratio * positive pairs] background cells with the highest CE</span>
- <span class="sd"> (the hardest background cells)</span>
- <span class="sd"> """</span>
- <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">neg_pos_ratio</span><span class="p">:</span> <span class="nb">float</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> :param neg_pos_ratio: a ratio of negative samples to positive samples in the loss</span>
- <span class="sd"> (unlike positives, not all negatives will be used:</span>
- <span class="sd"> for each positive the [neg_pos_ratio] hardest negatives will be selected)</span>
- <span class="sd"> """</span>
- <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">neg_pos_ratio</span> <span class="o">=</span> <span class="n">neg_pos_ratio</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">ce</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">reduce</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
- <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">pred_labels</span><span class="p">,</span> <span class="n">target_labels</span><span class="p">):</span>
- <span class="n">mask</span> <span class="o">=</span> <span class="n">target_labels</span> <span class="o">></span> <span class="mi">0</span> <span class="c1"># not background</span>
- <span class="n">pos_num</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="c1"># HARD NEGATIVE MINING</span>
- <span class="n">con</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ce</span><span class="p">(</span><span class="n">pred_labels</span><span class="p">,</span> <span class="n">target_labels</span><span class="p">)</span>
- <span class="c1"># POSITIVE MASK WILL NOT BE SELECTED</span>
- <span class="c1"># set 0. loss for all positive objects, leave the loss where the object is background</span>
- <span class="n">con_neg</span> <span class="o">=</span> <span class="n">con</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
- <span class="n">con_neg</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
- <span class="c1"># sort background cells by CE loss value (bigger_first)</span>
- <span class="n">_</span><span class="p">,</span> <span class="n">con_idx</span> <span class="o">=</span> <span class="n">con_neg</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
- <span class="c1"># restore cells order, get each cell's order (rank) in CE loss sorting</span>
- <span class="n">_</span><span class="p">,</span> <span class="n">con_rank</span> <span class="o">=</span> <span class="n">con_idx</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="c1"># NUMBER OF NEGATIVE THREE TIMES POSITIVE</span>
- <span class="n">neg_num</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">neg_pos_ratio</span> <span class="o">*</span> <span class="n">pos_num</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">mask</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
- <span class="c1"># for each image into neg mask we'll take (3 * positive pairs) background objects with the highest CE</span>
- <span class="n">neg_mask</span> <span class="o">=</span> <span class="n">con_rank</span> <span class="o"><</span> <span class="n">neg_num</span>
- <span class="n">closs</span> <span class="o">=</span> <span class="p">(</span><span class="n">con</span> <span class="o">*</span> <span class="p">(</span><span class="n">mask</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">+</span> <span class="n">neg_mask</span><span class="o">.</span><span class="n">float</span><span class="p">()))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="k">return</span> <span class="n">closs</span>
- <div class="viewcode-block" id="SSDLoss"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.SSDLoss">[docs]</a><span class="k">class</span> <span class="nc">SSDLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> Implements the loss as the sum of the followings:</span>
- <span class="sd"> 1. Confidence Loss: All labels, with hard negative mining</span>
- <span class="sd"> 2. Localization Loss: Only on positive labels</span>
- <span class="sd"> L = (2 - alpha) * L_l1 + alpha * L_cls, where</span>
- <span class="sd"> * L_cls is HardMiningCrossEntropyLoss</span>
- <span class="sd"> * L_l1 = [SmoothL1Loss for all positives]</span>
- <span class="sd"> """</span>
- <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dboxes</span><span class="p">:</span> <span class="n">DefaultBoxes</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">iou_thresh</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">neg_pos_ratio</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">3.</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> :param dboxes: model anchors, shape [Num Grid Cells * Num anchors x 4]</span>
- <span class="sd"> :param alpha: a weighting factor between classification and regression loss</span>
- <span class="sd"> :param iou_thresh: a threshold for matching of anchors in each grid cell to GTs</span>
- <span class="sd"> (a match should have IoU > iou_thresh)</span>
- <span class="sd"> :param neg_pos_ratio: a ratio for HardMiningCrossEntropyLoss</span>
- <span class="sd"> """</span>
- <span class="nb">super</span><span class="p">(</span><span class="n">SSDLoss</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">scale_xy</span> <span class="o">=</span> <span class="n">dboxes</span><span class="o">.</span><span class="n">scale_xy</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">scale_wh</span> <span class="o">=</span> <span class="n">dboxes</span><span class="o">.</span><span class="n">scale_wh</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">dboxes</span><span class="p">(</span><span class="n">order</span><span class="o">=</span><span class="s2">"xywh"</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">sl1_loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">SmoothL1Loss</span><span class="p">(</span><span class="n">reduce</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">con_loss</span> <span class="o">=</span> <span class="n">HardMiningCrossEntropyLoss</span><span class="p">(</span><span class="n">neg_pos_ratio</span><span class="p">)</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresh</span> <span class="o">=</span> <span class="n">iou_thresh</span>
- <span class="nd">@property</span>
- <span class="k">def</span> <span class="nf">component_names</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> Component names for logging during training.</span>
- <span class="sd"> These correspond to 2nd item in the tuple returned in self.forward(...).</span>
- <span class="sd"> See super_gradients.Trainer.train() docs for more info.</span>
- <span class="sd"> """</span>
- <span class="k">return</span> <span class="p">[</span><span class="s2">"smooth_l1"</span><span class="p">,</span> <span class="s2">"closs"</span><span class="p">,</span> <span class="s2">"Loss"</span><span class="p">]</span>
- <span class="k">def</span> <span class="nf">_norm_relative_bbox</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loc</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> convert bbox locations into relative locations (relative to the dboxes)</span>
- <span class="sd"> :param loc a tensor of shape [batch, 4, num_boxes]</span>
- <span class="sd"> """</span>
- <span class="n">gxy</span> <span class="o">=</span> <span class="p">((</span><span class="n">loc</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">,</span> <span class="p">:]</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">,</span> <span class="p">:])</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">])</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_xy</span>
- <span class="n">gwh</span> <span class="o">=</span> <span class="p">(</span><span class="n">loc</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">:]</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">:])</span><span class="o">.</span><span class="n">log</span><span class="p">()</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_wh</span>
- <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">gxy</span><span class="p">,</span> <span class="n">gwh</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
- <div class="viewcode-block" id="SSDLoss.match_dboxes"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.SSDLoss.match_dboxes">[docs]</a> <span class="k">def</span> <span class="nf">match_dboxes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> creates tensors with target boxes and labels for each dboxes, so with the same len as dboxes.</span>
- <span class="sd"> * Each GT is assigned with a grid cell with the highest IoU, this creates a pair for each GT and some cells;</span>
- <span class="sd"> * The rest of grid cells are assigned to a GT with the highest IoU, assuming it's > self.iou_thresh;</span>
- <span class="sd"> If this condition is not met the grid cell is marked as background</span>
- <span class="sd"> GT-wise: one to many</span>
- <span class="sd"> Grid-cell-wise: one to one</span>
- <span class="sd"> :param targets: a tensor containing the boxes for a single image;</span>
- <span class="sd"> shape [num_boxes, 6] (image_id, label, x, y, w, h)</span>
- <span class="sd"> :return: two tensors</span>
- <span class="sd"> boxes - shape of dboxes [4, num_dboxes] (x,y,w,h)</span>
- <span class="sd"> labels - sahpe [num_dboxes]</span>
- <span class="sd"> """</span>
- <span class="n">device</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">device</span>
- <span class="n">each_cell_target_locations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
- <span class="n">each_cell_target_labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
- <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
- <span class="n">target_boxes</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span>
- <span class="n">target_labels</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span>
- <span class="n">ious</span> <span class="o">=</span> <span class="n">calculate_bbox_iou_matrix</span><span class="p">(</span><span class="n">target_boxes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">x1y1x2y2</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
- <span class="c1"># one best GT for EACH cell (does not guarantee that all GTs will be used)</span>
- <span class="n">best_target_per_cell</span><span class="p">,</span> <span class="n">best_target_per_cell_index</span> <span class="o">=</span> <span class="n">ious</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
- <span class="c1"># one best grid cell (anchor in it) for EACH target</span>
- <span class="n">best_cell_per_target</span><span class="p">,</span> <span class="n">best_cell_per_target_index</span> <span class="o">=</span> <span class="n">ious</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
- <span class="c1"># make sure EACH target has a grid cell assigned</span>
- <span class="n">best_target_per_cell_index</span><span class="p">[</span><span class="n">best_cell_per_target_index</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
- <span class="c1"># 2. is higher than any IoU, so it is guaranteed to pass any IoU threshold</span>
- <span class="c1"># which ensures that the pairs selected for each target will be included in the mask below</span>
- <span class="c1"># while the threshold will only affect other grid cell anchors that aren't pre-assigned to any target</span>
- <span class="n">best_target_per_cell</span><span class="p">[</span><span class="n">best_cell_per_target_index</span><span class="p">]</span> <span class="o">=</span> <span class="mf">2.</span>
- <span class="n">mask</span> <span class="o">=</span> <span class="n">best_target_per_cell</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresh</span>
- <span class="n">each_cell_target_locations</span><span class="p">[:,</span> <span class="n">mask</span><span class="p">]</span> <span class="o">=</span> <span class="n">target_boxes</span><span class="p">[</span><span class="n">best_target_per_cell_index</span><span class="p">[</span><span class="n">mask</span><span class="p">]]</span><span class="o">.</span><span class="n">T</span>
- <span class="n">each_cell_target_labels</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span> <span class="o">=</span> <span class="n">target_labels</span><span class="p">[</span><span class="n">best_target_per_cell_index</span><span class="p">[</span><span class="n">mask</span><span class="p">]]</span> <span class="o">+</span> <span class="mi">1</span>
- <span class="k">return</span> <span class="n">each_cell_target_locations</span><span class="p">,</span> <span class="n">each_cell_target_labels</span></div>
- <div class="viewcode-block" id="SSDLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.losses.SSDLoss.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> Compute the loss</span>
- <span class="sd"> :param predictions - predictions tensor coming from the network,</span>
- <span class="sd"> tuple with shapes ([Batch Size, 4, num_dboxes], [Batch Size, num_classes + 1, num_dboxes])</span>
- <span class="sd"> were predictions have logprobs for background and other classes</span>
- <span class="sd"> :param targets - targets for the batch. [num targets, 6] (index in batch, label, x,y,w,h)</span>
- <span class="sd"> """</span>
- <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">predictions</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="nb">tuple</span><span class="p">):</span>
- <span class="c1"># Calculate loss in a validation mode</span>
- <span class="n">predictions</span> <span class="o">=</span> <span class="n">predictions</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
- <span class="n">batch_target_locations</span> <span class="o">=</span> <span class="p">[]</span>
- <span class="n">batch_target_labels</span> <span class="o">=</span> <span class="p">[]</span>
- <span class="p">(</span><span class="n">ploc</span><span class="p">,</span> <span class="n">plabel</span><span class="p">)</span> <span class="o">=</span> <span class="n">predictions</span>
- <span class="n">targets</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
- <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ploc</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
- <span class="n">target_locations</span><span class="p">,</span> <span class="n">target_labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">match_dboxes</span><span class="p">(</span><span class="n">targets</span><span class="p">[</span><span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">i</span><span class="p">])</span>
- <span class="n">batch_target_locations</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">target_locations</span><span class="p">)</span>
- <span class="n">batch_target_labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">target_labels</span><span class="p">)</span>
- <span class="n">batch_target_locations</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_target_locations</span><span class="p">)</span>
- <span class="n">batch_target_labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_target_labels</span><span class="p">)</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span>
- <span class="n">mask</span> <span class="o">=</span> <span class="n">batch_target_labels</span> <span class="o">></span> <span class="mi">0</span> <span class="c1"># not background</span>
- <span class="n">pos_num</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="n">vec_gd</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_norm_relative_bbox</span><span class="p">(</span><span class="n">batch_target_locations</span><span class="p">)</span>
- <span class="c1"># SUM ON FOUR COORDINATES, AND MASK</span>
- <span class="n">sl1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sl1_loss</span><span class="p">(</span><span class="n">ploc</span><span class="p">,</span> <span class="n">vec_gd</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="n">sl1</span> <span class="o">=</span> <span class="p">(</span><span class="n">mask</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">*</span> <span class="n">sl1</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="n">closs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">con_loss</span><span class="p">(</span><span class="n">plabel</span><span class="p">,</span> <span class="n">batch_target_labels</span><span class="p">)</span>
- <span class="c1"># AVOID NO OBJECT DETECTED</span>
- <span class="n">total_loss</span> <span class="o">=</span> <span class="p">(</span><span class="mi">2</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">)</span> <span class="o">*</span> <span class="n">sl1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">*</span> <span class="n">closs</span>
- <span class="n">num_mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">pos_num</span> <span class="o">></span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="c1"># a mask with 0 for images that have no positive pairs at all</span>
- <span class="n">pos_num</span> <span class="o">=</span> <span class="n">pos_num</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)</span>
- <span class="n">ret</span> <span class="o">=</span> <span class="p">(</span><span class="n">total_loss</span> <span class="o">*</span> <span class="n">num_mask</span> <span class="o">/</span> <span class="n">pos_num</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># normalize by the number of positive pairs</span>
- <span class="k">return</span> <span class="n">ret</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">sl1</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">closs</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">ret</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)))</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>
- </pre></div>
- </div>
- </div>
- <footer>
- <hr/>
- <div role="contentinfo">
- <p>© Copyright 2021, SuperGradients team.</p>
- </div>
- Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
- <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
- provided by <a href="https://readthedocs.org">Read the Docs</a>.
-
- </footer>
- </div>
- </div>
- </section>
- </div>
- <script>
- jQuery(function () {
- SphinxRtdTheme.Navigation.enable(true);
- });
- </script>
- </body>
- </html>
|