1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
|
- <!DOCTYPE html>
- <html class="writer-html5" lang="en" >
- <head>
- <meta charset="utf-8" />
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
- <title>super_gradients.training.metrics.detection_metrics — SuperGradients 1.0 documentation</title>
- <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
- <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
- <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
- <!--[if lt IE 9]>
- <script src="../../../../_static/js/html5shiv.min.js"></script>
- <![endif]-->
-
- <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
- <script src="../../../../_static/jquery.js"></script>
- <script src="../../../../_static/underscore.js"></script>
- <script src="../../../../_static/doctools.js"></script>
- <script src="../../../../_static/js/theme.js"></script>
- <link rel="index" title="Index" href="../../../../genindex.html" />
- <link rel="search" title="Search" href="../../../../search.html" />
- </head>
- <body class="wy-body-for-nav">
- <div class="wy-grid-for-nav">
- <nav data-toggle="wy-nav-shift" class="wy-nav-side">
- <div class="wy-side-scroll">
- <div class="wy-side-nav-search" >
- <a href="../../../../index.html" class="icon icon-home"> SuperGradients
- </a>
- <div role="search">
- <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
- <input type="text" name="q" placeholder="Search docs" />
- <input type="hidden" name="check_keywords" value="yes" />
- <input type="hidden" name="area" value="default" />
- </form>
- </div>
- </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
- <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
- <ul>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -> Fill Survey</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
- </ul>
- <p class="caption"><span class="caption-text">Technical Documentation</span></p>
- <ul>
- <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
- </ul>
- <p class="caption"><span class="caption-text">User Guide</span></p>
- <ul>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
- </ul>
- </div>
- </div>
- </nav>
- <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
- <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
- <a href="../../../../index.html">SuperGradients</a>
- </nav>
- <div class="wy-nav-content">
- <div class="rst-content">
- <div role="navigation" aria-label="Page navigation">
- <ul class="wy-breadcrumbs">
- <li><a href="../../../../index.html" class="icon icon-home"></a> »</li>
- <li><a href="../../../index.html">Module code</a> »</li>
- <li>super_gradients.training.metrics.detection_metrics</li>
- <li class="wy-breadcrumbs-aside">
- </li>
- </ul>
- <hr/>
- </div>
- <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
- <div itemprop="articleBody">
-
- <h1>Source code for super_gradients.training.metrics.detection_metrics</h1><div class="highlight"><pre>
- <span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Union</span>
- <span class="kn">import</span> <span class="nn">torch</span>
- <span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">Metric</span>
- <span class="kn">import</span> <span class="nn">super_gradients</span>
- <span class="kn">from</span> <span class="nn">super_gradients.training.utils</span> <span class="kn">import</span> <span class="n">tensor_container_to_device</span>
- <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">compute_detection_matching</span><span class="p">,</span> <span class="n">compute_detection_metrics</span>
- <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">DetectionPostPredictionCallback</span><span class="p">,</span> <span class="n">IouThreshold</span>
- <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
- <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
- <div class="viewcode-block" id="DetectionMetrics"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.DetectionMetrics">[docs]</a><span class="k">class</span> <span class="nc">DetectionMetrics</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> DetectionMetrics</span>
- <span class="sd"> Metric class for computing F1, Precision, Recall and Mean Average Precision.</span>
- <span class="sd"> Attributes:</span>
- <span class="sd"> num_cls: Number of classes.</span>
- <span class="sd"> post_prediction_callback: DetectionPostPredictionCallback to be applied on net's output prior</span>
- <span class="sd"> to the metric computation (NMS).</span>
- <span class="sd"> normalize_targets: Whether to normalize bbox coordinates by image size (default=False).</span>
- <span class="sd"> iou_thresholds: IoU threshold to compute the mAP (default=torch.linspace(0.5, 0.95, 10)).</span>
- <span class="sd"> recall_thresholds: Recall threshold to compute the mAP (default=torch.linspace(0, 1, 101)).</span>
- <span class="sd"> score_threshold: Score threshold to compute Recall, Precision and F1 (default=0.1)</span>
- <span class="sd"> top_k_predictions: Number of predictions per class used to compute metrics, ordered by confidence score</span>
- <span class="sd"> (default=100)</span>
- <span class="sd"> dist_sync_on_step: Synchronize metric state across processes at each ``forward()``</span>
- <span class="sd"> before returning the value at the step. (default=False)</span>
- <span class="sd"> accumulate_on_cpu: Run on CPU regardless of device used in other parts.</span>
- <span class="sd"> This is to avoid "CUDA out of memory" that might happen on GPU (default False)</span>
- <span class="sd"> """</span>
- <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_cls</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
- <span class="n">post_prediction_callback</span><span class="p">:</span> <span class="n">DetectionPostPredictionCallback</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
- <span class="n">normalize_targets</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
- <span class="n">iou_thres</span><span class="p">:</span> <span class="n">IouThreshold</span> <span class="o">=</span> <span class="n">IouThreshold</span><span class="o">.</span><span class="n">MAP_05_TO_095</span><span class="p">,</span>
- <span class="n">recall_thres</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
- <span class="n">score_thres</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
- <span class="n">top_k_predictions</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
- <span class="n">dist_sync_on_step</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
- <span class="n">accumulate_on_cpu</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
- <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">dist_sync_on_step</span><span class="o">=</span><span class="n">dist_sync_on_step</span><span class="p">)</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">num_cls</span> <span class="o">=</span> <span class="n">num_cls</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">iou_thres</span> <span class="o">=</span> <span class="n">iou_thres</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">map_str</span> <span class="o">=</span> <span class="s1">'mAP@</span><span class="si">%.1f</span><span class="s1">'</span> <span class="o">%</span> <span class="n">iou_thres</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">iou_thres</span><span class="o">.</span><span class="n">is_range</span><span class="p">()</span> <span class="k">else</span> <span class="s1">'mAP@</span><span class="si">%.2f</span><span class="s1">:</span><span class="si">%.2f</span><span class="s1">'</span> <span class="o">%</span> <span class="n">iou_thres</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">component_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"Precision"</span><span class="p">,</span> <span class="s2">"Recall"</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">map_str</span><span class="p">,</span> <span class="s2">"F1"</span><span class="p">]</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">components</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">component_names</span><span class="p">)</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span> <span class="o">=</span> <span class="n">post_prediction_callback</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="o">=</span> <span class="n">super_gradients</span><span class="o">.</span><span class="n">is_distributed</span><span class="p">()</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">denormalize_targets</span> <span class="o">=</span> <span class="ow">not</span> <span class="n">normalize_targets</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">=</span> <span class="kc">None</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="kc">None</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">add_state</span><span class="p">(</span><span class="s2">"matching_info"</span><span class="p">,</span> <span class="n">default</span><span class="o">=</span><span class="p">[],</span> <span class="n">dist_reduce_fx</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span> <span class="o">=</span> <span class="n">iou_thres</span><span class="o">.</span><span class="n">to_tensor</span><span class="p">()</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">recall_thresholds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">101</span><span class="p">)</span> <span class="k">if</span> <span class="n">recall_thres</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">recall_thres</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">score_threshold</span> <span class="o">=</span> <span class="n">score_thres</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">top_k_predictions</span> <span class="o">=</span> <span class="n">top_k_predictions</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="o">=</span> <span class="n">accumulate_on_cpu</span>
- <div class="viewcode-block" id="DetectionMetrics.update"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.DetectionMetrics.update">[docs]</a> <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
- <span class="n">inputs</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">,</span> <span class="n">crowd_targets</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> Apply NMS and match all the predictions and targets of a given batch, and update the metric state accordingly.</span>
- <span class="sd"> :param preds : Raw output of the model, the format might change from one model to another, but has to fit</span>
- <span class="sd"> the input format of the post_prediction_callback</span>
- <span class="sd"> :param target: Targets for all images of shape (total_num_targets, 6)</span>
- <span class="sd"> format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]</span>
- <span class="sd"> :param device: Device to run on</span>
- <span class="sd"> :param inputs: Input image tensor of shape (batch_size, n_img, height, width)</span>
- <span class="sd"> :param crowd_targets: Crowd targets for all images of shape (total_num_targets, 6)</span>
- <span class="sd"> format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]</span>
- <span class="sd"> """</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
- <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">shape</span>
- <span class="n">targets</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
- <span class="n">crowd_targets</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span> <span class="k">if</span> <span class="n">crowd_targets</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">crowd_targets</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
- <span class="n">preds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">post_prediction_callback</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span>
- <span class="n">new_matching_info</span> <span class="o">=</span> <span class="n">compute_detection_matching</span><span class="p">(</span>
- <span class="n">preds</span><span class="p">,</span> <span class="n">targets</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">iou_thresholds</span><span class="p">,</span> <span class="n">crowd_targets</span><span class="o">=</span><span class="n">crowd_targets</span><span class="p">,</span>
- <span class="n">top_k</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">top_k_predictions</span><span class="p">,</span> <span class="n">denormalize_targets</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">denormalize_targets</span><span class="p">,</span>
- <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">return_on_cpu</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span><span class="p">)</span>
- <span class="n">accumulated_matching_info</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">"matching_info"</span><span class="p">)</span>
- <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">"matching_info"</span><span class="p">,</span> <span class="n">accumulated_matching_info</span> <span class="o">+</span> <span class="n">new_matching_info</span><span class="p">)</span></div>
- <div class="viewcode-block" id="DetectionMetrics.compute"><a class="viewcode-back" href="../../../../super_gradients.training.metrics.html#super_gradients.training.metrics.DetectionMetrics.compute">[docs]</a> <span class="k">def</span> <span class="nf">compute</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]]:</span>
- <span class="sd">"""Compute the metrics for all the accumulated results.</span>
- <span class="sd"> :return: Metrics of interest</span>
- <span class="sd"> """</span>
- <span class="n">mean_ap</span><span class="p">,</span> <span class="n">mean_precision</span><span class="p">,</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="n">mean_f1</span> <span class="o">=</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.</span>
- <span class="n">accumulated_matching_info</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">"matching_info"</span><span class="p">)</span>
- <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">accumulated_matching_info</span><span class="p">):</span>
- <span class="n">matching_info_tensors</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">accumulated_matching_info</span><span class="p">))]</span>
- <span class="c1"># shape (n_class, nb_iou_thresh)</span>
- <span class="n">ap</span><span class="p">,</span> <span class="n">precision</span><span class="p">,</span> <span class="n">recall</span><span class="p">,</span> <span class="n">f1</span><span class="p">,</span> <span class="n">unique_classes</span> <span class="o">=</span> <span class="n">compute_detection_metrics</span><span class="p">(</span>
- <span class="o">*</span><span class="n">matching_info_tensors</span><span class="p">,</span> <span class="n">recall_thresholds</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">recall_thresholds</span><span class="p">,</span> <span class="n">score_threshold</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">score_threshold</span><span class="p">,</span>
- <span class="n">device</span><span class="o">=</span><span class="s2">"cpu"</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
- <span class="c1"># Precision, recall and f1 are computed for smallest IoU threshold (usually 0.5), averaged over classes</span>
- <span class="n">mean_precision</span><span class="p">,</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="n">mean_f1</span> <span class="o">=</span> <span class="n">precision</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">mean</span><span class="p">(),</span> <span class="n">recall</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">mean</span><span class="p">(),</span> <span class="n">f1</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
- <span class="c1"># MaP is averaged over IoU thresholds and over classes</span>
- <span class="n">mean_ap</span> <span class="o">=</span> <span class="n">ap</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
- <span class="k">return</span> <span class="p">{</span><span class="s2">"Precision"</span><span class="p">:</span> <span class="n">mean_precision</span><span class="p">,</span> <span class="s2">"Recall"</span><span class="p">:</span> <span class="n">mean_recall</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">map_str</span><span class="p">:</span> <span class="n">mean_ap</span><span class="p">,</span> <span class="s2">"F1"</span><span class="p">:</span> <span class="n">mean_f1</span><span class="p">}</span></div>
- <span class="k">def</span> <span class="nf">_sync_dist</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dist_sync_fn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">process_group</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> When in distributed mode, stats are aggregated after each forward pass to the metric state. Since these have all</span>
- <span class="sd"> different sizes we override the synchronization function since it works only for tensors (and use</span>
- <span class="sd"> all_gather_object)</span>
- <span class="sd"> @param dist_sync_fn:</span>
- <span class="sd"> @return:</span>
- <span class="sd"> """</span>
- <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">get_world_size</span><span class="p">()</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="k">else</span> <span class="o">-</span><span class="mi">1</span>
- <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">get_rank</span><span class="p">()</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span> <span class="k">else</span> <span class="o">-</span><span class="mi">1</span>
- <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_distributed</span><span class="p">:</span>
- <span class="n">local_state_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">attr</span><span class="p">:</span> <span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attr</span><span class="p">)</span> <span class="k">for</span> <span class="n">attr</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_reductions</span><span class="o">.</span><span class="n">keys</span><span class="p">()}</span>
- <span class="n">gathered_state_dicts</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">world_size</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">barrier</span><span class="p">()</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">all_gather_object</span><span class="p">(</span><span class="n">gathered_state_dicts</span><span class="p">,</span> <span class="n">local_state_dict</span><span class="p">)</span>
- <span class="n">matching_info</span> <span class="o">=</span> <span class="p">[]</span>
- <span class="k">for</span> <span class="n">state_dict</span> <span class="ow">in</span> <span class="n">gathered_state_dicts</span><span class="p">:</span>
- <span class="n">matching_info</span> <span class="o">+=</span> <span class="n">state_dict</span><span class="p">[</span><span class="s2">"matching_info"</span><span class="p">]</span>
- <span class="n">matching_info</span> <span class="o">=</span> <span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">matching_info</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">"cpu"</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">accumulate_on_cpu</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
- <span class="nb">setattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="s2">"matching_info"</span><span class="p">,</span> <span class="n">matching_info</span><span class="p">)</span></div>
- </pre></div>
- </div>
- </div>
- <footer>
- <hr/>
- <div role="contentinfo">
- <p>© Copyright 2021, SuperGradients team.</p>
- </div>
- Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
- <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
- provided by <a href="https://readthedocs.org">Read the Docs</a>.
-
- </footer>
- </div>
- </div>
- </section>
- </div>
- <script>
- jQuery(function () {
- SphinxRtdTheme.Navigation.enable(true);
- });
- </script>
- </body>
- </html>
|