Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#396 Trainer constructor cleanup

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-266_clean_trainer_ctor
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.utils &mdash; SuperGradients 1.0 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <!--[if lt IE 9]>
  11. <script src="../../../../_static/js/html5shiv.min.js"></script>
  12. <![endif]-->
  13. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  14. <script src="../../../../_static/jquery.js"></script>
  15. <script src="../../../../_static/underscore.js"></script>
  16. <script src="../../../../_static/doctools.js"></script>
  17. <script src="../../../../_static/js/theme.js"></script>
  18. <link rel="index" title="Index" href="../../../../genindex.html" />
  19. <link rel="search" title="Search" href="../../../../search.html" />
  20. </head>
  21. <body class="wy-body-for-nav">
  22. <div class="wy-grid-for-nav">
  23. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  24. <div class="wy-side-scroll">
  25. <div class="wy-side-nav-search" >
  26. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  27. </a>
  28. <div role="search">
  29. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  30. <input type="text" name="q" placeholder="Search docs" />
  31. <input type="hidden" name="check_keywords" value="yes" />
  32. <input type="hidden" name="area" value="default" />
  33. </form>
  34. </div>
  35. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  36. <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
  37. <ul>
  38. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -&gt; Fill Survey</a></li>
  39. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
  40. </ul>
  41. <p class="caption"><span class="caption-text">Technical Documentation</span></p>
  42. <ul>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  45. </ul>
  46. <p class="caption"><span class="caption-text">User Guide</span></p>
  47. <ul>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
  57. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
  58. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
  59. <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
  60. </ul>
  61. </div>
  62. </div>
  63. </nav>
  64. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  65. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  66. <a href="../../../../index.html">SuperGradients</a>
  67. </nav>
  68. <div class="wy-nav-content">
  69. <div class="rst-content">
  70. <div role="navigation" aria-label="Page navigation">
  71. <ul class="wy-breadcrumbs">
  72. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  73. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  74. <li>super_gradients.training.utils.utils</li>
  75. <li class="wy-breadcrumbs-aside">
  76. </li>
  77. </ul>
  78. <hr/>
  79. </div>
  80. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  81. <div itemprop="articleBody">
  82. <h1>Source code for super_gradients.training.utils.utils</h1><div class="highlight"><pre>
  83. <span></span><span class="kn">import</span> <span class="nn">math</span>
  84. <span class="kn">import</span> <span class="nn">time</span>
  85. <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">lru_cache</span>
  86. <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
  87. <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Mapping</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">List</span>
  88. <span class="kn">from</span> <span class="nn">zipfile</span> <span class="kn">import</span> <span class="n">ZipFile</span>
  89. <span class="kn">import</span> <span class="nn">os</span>
  90. <span class="kn">from</span> <span class="nn">jsonschema</span> <span class="kn">import</span> <span class="n">validate</span>
  91. <span class="kn">import</span> <span class="nn">tarfile</span>
  92. <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">ExifTags</span>
  93. <span class="kn">import</span> <span class="nn">torch</span>
  94. <span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
  95. <span class="c1"># These functions changed from torch 1.2 to torch 1.3</span>
  96. <span class="kn">import</span> <span class="nn">random</span>
  97. <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
  98. <span class="kn">from</span> <span class="nn">importlib</span> <span class="kn">import</span> <span class="n">import_module</span>
  99. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
  100. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
  101. <div class="viewcode-block" id="convert_to_tensor"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.convert_to_tensor">[docs]</a><span class="k">def</span> <span class="nf">convert_to_tensor</span><span class="p">(</span><span class="n">array</span><span class="p">):</span>
  102. <span class="sd">&quot;&quot;&quot;Converts numpy arrays and lists to Torch tensors before calculation losses</span>
  103. <span class="sd"> :param array: torch.tensor / Numpy array / List</span>
  104. <span class="sd"> &quot;&quot;&quot;</span>
  105. <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">FloatTensor</span><span class="p">(</span><span class="n">array</span><span class="p">)</span> <span class="k">if</span> <span class="nb">type</span><span class="p">(</span><span class="n">array</span><span class="p">)</span> <span class="o">!=</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span> <span class="k">else</span> <span class="n">array</span></div>
  106. <div class="viewcode-block" id="HpmStruct"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct">[docs]</a><span class="k">class</span> <span class="nc">HpmStruct</span><span class="p">:</span>
  107. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
  108. <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">entries</span><span class="p">)</span>
  109. <span class="bp">self</span><span class="o">.</span><span class="n">schema</span> <span class="o">=</span> <span class="kc">None</span>
  110. <div class="viewcode-block" id="HpmStruct.set_schema"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct.set_schema">[docs]</a> <span class="k">def</span> <span class="nf">set_schema</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">schema</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
  111. <span class="bp">self</span><span class="o">.</span><span class="n">schema</span> <span class="o">=</span> <span class="n">schema</span></div>
  112. <div class="viewcode-block" id="HpmStruct.override"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct.override">[docs]</a> <span class="k">def</span> <span class="nf">override</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">entries</span><span class="p">):</span>
  113. <span class="n">recursive_override</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="n">entries</span><span class="p">)</span></div>
  114. <div class="viewcode-block" id="HpmStruct.to_dict"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct.to_dict">[docs]</a> <span class="k">def</span> <span class="nf">to_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  115. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span></div>
  116. <div class="viewcode-block" id="HpmStruct.validate"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.HpmStruct.validate">[docs]</a> <span class="k">def</span> <span class="nf">validate</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  117. <span class="sd">&quot;&quot;&quot;</span>
  118. <span class="sd"> Validate the current dict values according to the provided schema</span>
  119. <span class="sd"> :raises</span>
  120. <span class="sd"> `AttributeError` if schema was not set</span>
  121. <span class="sd"> `jsonschema.exceptions.ValidationError` if the instance is invalid</span>
  122. <span class="sd"> `jsonschema.exceptions.SchemaError` if the schema itselfis invalid</span>
  123. <span class="sd"> &quot;&quot;&quot;</span>
  124. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">schema</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  125. <span class="k">raise</span> <span class="ne">AttributeError</span><span class="p">(</span><span class="s1">&#39;schema was not set&#39;</span><span class="p">)</span>
  126. <span class="k">else</span><span class="p">:</span>
  127. <span class="n">validate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="vm">__dict__</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">schema</span><span class="p">)</span></div></div>
  128. <div class="viewcode-block" id="WrappedModel"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.WrappedModel">[docs]</a><span class="k">class</span> <span class="nc">WrappedModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  129. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">module</span><span class="p">):</span>
  130. <span class="nb">super</span><span class="p">(</span><span class="n">WrappedModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
  131. <span class="bp">self</span><span class="o">.</span><span class="n">module</span> <span class="o">=</span> <span class="n">module</span> <span class="c1"># that I actually define.</span>
  132. <div class="viewcode-block" id="WrappedModel.forward"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.WrappedModel.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
  133. <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">module</span><span class="p">(</span><span class="n">x</span><span class="p">)</span></div></div>
  134. <div class="viewcode-block" id="Timer"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.Timer">[docs]</a><span class="k">class</span> <span class="nc">Timer</span><span class="p">:</span>
  135. <span class="sd">&quot;&quot;&quot;A class to measure time handling both GPU &amp; CPU processes</span>
  136. <span class="sd"> Returns time in milliseconds&quot;&quot;&quot;</span>
  137. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  138. <span class="sd">&quot;&quot;&quot;</span>
  139. <span class="sd"> :param device: str</span>
  140. <span class="sd"> &#39;cpu&#39;\&#39;cuda&#39;</span>
  141. <span class="sd"> &quot;&quot;&quot;</span>
  142. <span class="bp">self</span><span class="o">.</span><span class="n">on_gpu</span> <span class="o">=</span> <span class="p">(</span><span class="n">device</span> <span class="o">==</span> <span class="s1">&#39;cuda&#39;</span><span class="p">)</span>
  143. <span class="c1"># On GPU time is measured using cuda.events</span>
  144. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_gpu</span><span class="p">:</span>
  145. <span class="bp">self</span><span class="o">.</span><span class="n">starter</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Event</span><span class="p">(</span><span class="n">enable_timing</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  146. <span class="bp">self</span><span class="o">.</span><span class="n">ender</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">Event</span><span class="p">(</span><span class="n">enable_timing</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  147. <span class="c1"># On CPU time is measured using time</span>
  148. <span class="k">else</span><span class="p">:</span>
  149. <span class="bp">self</span><span class="o">.</span><span class="n">starter</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ender</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span>
  150. <div class="viewcode-block" id="Timer.start"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.Timer.start">[docs]</a> <span class="k">def</span> <span class="nf">start</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  151. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_gpu</span><span class="p">:</span>
  152. <span class="bp">self</span><span class="o">.</span><span class="n">starter</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
  153. <span class="k">else</span><span class="p">:</span>
  154. <span class="bp">self</span><span class="o">.</span><span class="n">starter</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span></div>
  155. <div class="viewcode-block" id="Timer.stop"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.Timer.stop">[docs]</a> <span class="k">def</span> <span class="nf">stop</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  156. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_gpu</span><span class="p">:</span>
  157. <span class="bp">self</span><span class="o">.</span><span class="n">ender</span><span class="o">.</span><span class="n">record</span><span class="p">()</span>
  158. <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
  159. <span class="n">timer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">starter</span><span class="o">.</span><span class="n">elapsed_time</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ender</span><span class="p">)</span>
  160. <span class="k">else</span><span class="p">:</span>
  161. <span class="c1"># Time measures in seconds -&gt; convert to milliseconds</span>
  162. <span class="n">timer</span> <span class="o">=</span> <span class="p">(</span><span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">starter</span><span class="p">)</span> <span class="o">*</span> <span class="mi">1000</span>
  163. <span class="c1"># Return time in milliseconds</span>
  164. <span class="k">return</span> <span class="n">timer</span></div></div>
  165. <div class="viewcode-block" id="AverageMeter"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.AverageMeter">[docs]</a><span class="k">class</span> <span class="nc">AverageMeter</span><span class="p">:</span>
  166. <span class="sd">&quot;&quot;&quot;A class to calculate the average of a metric, for each batch</span>
  167. <span class="sd"> during training/testing&quot;&quot;&quot;</span>
  168. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  169. <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">=</span> <span class="kc">None</span>
  170. <span class="bp">self</span><span class="o">.</span><span class="n">_count</span> <span class="o">=</span> <span class="mi">0</span>
  171. <div class="viewcode-block" id="AverageMeter.update"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.AverageMeter.update">[docs]</a> <span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">],</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
  172. <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
  173. <span class="n">value</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">value</span><span class="p">)</span>
  174. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  175. <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">=</span> <span class="n">value</span> <span class="o">*</span> <span class="n">batch_size</span>
  176. <span class="k">else</span><span class="p">:</span>
  177. <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">+=</span> <span class="n">value</span> <span class="o">*</span> <span class="n">batch_size</span>
  178. <span class="bp">self</span><span class="o">.</span><span class="n">_count</span> <span class="o">+=</span> <span class="n">batch_size</span></div>
  179. <span class="nd">@property</span>
  180. <span class="k">def</span> <span class="nf">average</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
  181. <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
  182. <span class="k">return</span> <span class="mi">0</span>
  183. <span class="k">return</span> <span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_count</span><span class="p">)</span><span class="o">.</span><span class="fm">__float__</span><span class="p">())</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">_sum</span><span class="o">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">&lt;</span> <span class="mi">1</span> <span class="k">else</span> <span class="nb">tuple</span><span class="p">(</span>
  184. <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_sum</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_count</span><span class="p">)</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span></div>
  185. <span class="c1"># return (self._sum / self._count).__float__() if self._sum.dim() &lt; 1 or len(self._sum) == 1 \</span>
  186. <span class="c1"># else tuple((self._sum / self._count).cpu().numpy())</span>
  187. <div class="viewcode-block" id="tensor_container_to_device"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.tensor_container_to_device">[docs]</a><span class="k">def</span> <span class="nf">tensor_container_to_device</span><span class="p">(</span><span class="n">obj</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">,</span> <span class="nb">list</span><span class="p">,</span> <span class="nb">dict</span><span class="p">],</span> <span class="n">device</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
  188. <span class="sd">&quot;&quot;&quot;</span>
  189. <span class="sd"> recursively send compounded objects to device (sending all tensors to device and maintaining structure)</span>
  190. <span class="sd"> :param obj the object to send to device (list / tuple / tensor / dict)</span>
  191. <span class="sd"> :param device: device to send the tensors to</span>
  192. <span class="sd"> :param non_blocking: used for DistributedDataParallel</span>
  193. <span class="sd"> :returns an object with the same structure (tensors, lists, tuples) with the device pointers (like</span>
  194. <span class="sd"> the return value of Tensor.to(device)</span>
  195. <span class="sd"> &quot;&quot;&quot;</span>
  196. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
  197. <span class="k">return</span> <span class="n">obj</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="n">non_blocking</span><span class="p">)</span>
  198. <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span>
  199. <span class="k">return</span> <span class="nb">tuple</span><span class="p">(</span><span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="n">non_blocking</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">obj</span><span class="p">)</span>
  200. <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span>
  201. <span class="k">return</span> <span class="p">[</span><span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="n">non_blocking</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">obj</span><span class="p">]</span>
  202. <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
  203. <span class="k">return</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">tensor_container_to_device</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="n">non_blocking</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">obj</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span>
  204. <span class="k">else</span><span class="p">:</span>
  205. <span class="k">return</span> <span class="n">obj</span></div>
  206. <div class="viewcode-block" id="get_param"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_param">[docs]</a><span class="k">def</span> <span class="nf">get_param</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">default_val</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
  207. <span class="sd">&quot;&quot;&quot;</span>
  208. <span class="sd"> Retrieves a param from a parameter object/dict. If the parameter does not exist, will return default_val.</span>
  209. <span class="sd"> In case the default_val is of type dictionary, and a value is found in the params - the function</span>
  210. <span class="sd"> will return the default value dictionary with internal values overridden by the found value</span>
  211. <span class="sd"> i.e.</span>
  212. <span class="sd"> default_opt_params = {&#39;lr&#39;:0.1, &#39;momentum&#39;:0.99, &#39;alpha&#39;:0.001}</span>
  213. <span class="sd"> training_params = {&#39;optimizer_params&#39;: {&#39;lr&#39;:0.0001}, &#39;batch&#39;: 32 .... }</span>
  214. <span class="sd"> get_param(training_params, name=&#39;optimizer_params&#39;, default_val=default_opt_params)</span>
  215. <span class="sd"> will return {&#39;lr&#39;:0.0001, &#39;momentum&#39;:0.99, &#39;alpha&#39;:0.001}</span>
  216. <span class="sd"> :param params: an object (typically HpmStruct) or a dict holding the params</span>
  217. <span class="sd"> :param name: name of the searched parameter</span>
  218. <span class="sd"> :param default_val: assumed to be the same type as the value searched in the params</span>
  219. <span class="sd"> :return: the found value, or default if not found</span>
  220. <span class="sd"> &quot;&quot;&quot;</span>
  221. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
  222. <span class="k">if</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">params</span><span class="p">:</span>
  223. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">default_val</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
  224. <span class="k">return</span> <span class="p">{</span><span class="o">**</span><span class="n">default_val</span><span class="p">,</span> <span class="o">**</span><span class="n">params</span><span class="p">[</span><span class="n">name</span><span class="p">]}</span>
  225. <span class="k">else</span><span class="p">:</span>
  226. <span class="k">return</span> <span class="n">params</span><span class="p">[</span><span class="n">name</span><span class="p">]</span>
  227. <span class="k">else</span><span class="p">:</span>
  228. <span class="k">return</span> <span class="n">default_val</span>
  229. <span class="k">elif</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
  230. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">default_val</span><span class="p">,</span> <span class="nb">dict</span><span class="p">):</span>
  231. <span class="k">return</span> <span class="p">{</span><span class="o">**</span><span class="n">default_val</span><span class="p">,</span> <span class="o">**</span><span class="nb">getattr</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">name</span><span class="p">)}</span>
  232. <span class="k">else</span><span class="p">:</span>
  233. <span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">name</span><span class="p">)</span>
  234. <span class="k">else</span><span class="p">:</span>
  235. <span class="k">return</span> <span class="n">default_val</span></div>
  236. <div class="viewcode-block" id="static_vars"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.static_vars">[docs]</a><span class="k">def</span> <span class="nf">static_vars</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
  237. <span class="k">def</span> <span class="nf">decorate</span><span class="p">(</span><span class="n">func</span><span class="p">):</span>
  238. <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">kwargs</span><span class="p">:</span>
  239. <span class="nb">setattr</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">kwargs</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
  240. <span class="k">return</span> <span class="n">func</span>
  241. <span class="k">return</span> <span class="n">decorate</span></div>
  242. <div class="viewcode-block" id="print_once"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.print_once">[docs]</a><span class="nd">@static_vars</span><span class="p">(</span><span class="n">printed</span><span class="o">=</span><span class="nb">set</span><span class="p">())</span>
  243. <span class="k">def</span> <span class="nf">print_once</span><span class="p">(</span><span class="n">s</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  244. <span class="k">if</span> <span class="n">s</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">print_once</span><span class="o">.</span><span class="n">printed</span><span class="p">:</span>
  245. <span class="n">print_once</span><span class="o">.</span><span class="n">printed</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">s</span><span class="p">)</span>
  246. <span class="nb">print</span><span class="p">(</span><span class="n">s</span><span class="p">)</span></div>
  247. <div class="viewcode-block" id="move_state_dict_to_device"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.move_state_dict_to_device">[docs]</a><span class="k">def</span> <span class="nf">move_state_dict_to_device</span><span class="p">(</span><span class="n">model_sd</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
  248. <span class="sd">&quot;&quot;&quot;</span>
  249. <span class="sd"> Moving model state dict tensors to target device (cuda or cpu)</span>
  250. <span class="sd"> :param model_sd: model state dict</span>
  251. <span class="sd"> :param device: either cuda or cpu</span>
  252. <span class="sd"> &quot;&quot;&quot;</span>
  253. <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model_sd</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  254. <span class="n">model_sd</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
  255. <span class="k">return</span> <span class="n">model_sd</span></div>
  256. <div class="viewcode-block" id="random_seed"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.random_seed">[docs]</a><span class="k">def</span> <span class="nf">random_seed</span><span class="p">(</span><span class="n">is_ddp</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">seed</span><span class="p">):</span>
  257. <span class="sd">&quot;&quot;&quot;</span>
  258. <span class="sd"> Sets random seed of numpy, torch and random.</span>
  259. <span class="sd"> When using ddp a seed will be set for each process according to its local rank derived from the device number.</span>
  260. <span class="sd"> :param is_ddp: bool, will set different random seed for each process when using ddp.</span>
  261. <span class="sd"> :param device: &#39;cuda&#39;,&#39;cpu&#39;, &#39;cuda:&lt;device_number&gt;&#39;</span>
  262. <span class="sd"> :param seed: int, random seed to be set</span>
  263. <span class="sd"> &quot;&quot;&quot;</span>
  264. <span class="n">rank</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">is_ddp</span> <span class="k">else</span> <span class="nb">int</span><span class="p">(</span><span class="n">device</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">&#39;:&#39;</span><span class="p">)[</span><span class="mi">1</span><span class="p">])</span>
  265. <span class="n">torch</span><span class="o">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span> <span class="o">+</span> <span class="n">rank</span><span class="p">)</span>
  266. <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span> <span class="o">+</span> <span class="n">rank</span><span class="p">)</span>
  267. <span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span> <span class="o">+</span> <span class="n">rank</span><span class="p">)</span></div>
  268. <div class="viewcode-block" id="load_func"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.load_func">[docs]</a><span class="k">def</span> <span class="nf">load_func</span><span class="p">(</span><span class="n">dotpath</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  269. <span class="sd">&quot;&quot;&quot;</span>
  270. <span class="sd"> load function in module. function is right-most segment.</span>
  271. <span class="sd"> Used for passing functions (without calling them) in yaml files.</span>
  272. <span class="sd"> @param dotpath: path to module.</span>
  273. <span class="sd"> @return: a python function</span>
  274. <span class="sd"> &quot;&quot;&quot;</span>
  275. <span class="n">module_</span><span class="p">,</span> <span class="n">func</span> <span class="o">=</span> <span class="n">dotpath</span><span class="o">.</span><span class="n">rsplit</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">,</span> <span class="n">maxsplit</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
  276. <span class="n">m</span> <span class="o">=</span> <span class="n">import_module</span><span class="p">(</span><span class="n">module_</span><span class="p">)</span>
  277. <span class="k">return</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">func</span><span class="p">)</span></div>
  278. <div class="viewcode-block" id="get_filename_suffix_by_framework"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_filename_suffix_by_framework">[docs]</a><span class="k">def</span> <span class="nf">get_filename_suffix_by_framework</span><span class="p">(</span><span class="n">framework</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  279. <span class="sd">&quot;&quot;&quot;</span>
  280. <span class="sd"> Return the file extension of framework.</span>
  281. <span class="sd"> @param framework: (str)</span>
  282. <span class="sd"> @return: (str) the suffix for the specific framework</span>
  283. <span class="sd"> &quot;&quot;&quot;</span>
  284. <span class="n">frameworks_dict</span> <span class="o">=</span> \
  285. <span class="p">{</span>
  286. <span class="s1">&#39;TENSORFLOW1&#39;</span><span class="p">:</span> <span class="s1">&#39;.pb&#39;</span><span class="p">,</span>
  287. <span class="s1">&#39;TENSORFLOW2&#39;</span><span class="p">:</span> <span class="s1">&#39;.zip&#39;</span><span class="p">,</span>
  288. <span class="s1">&#39;PYTORCH&#39;</span><span class="p">:</span> <span class="s1">&#39;.pth&#39;</span><span class="p">,</span>
  289. <span class="s1">&#39;ONNX&#39;</span><span class="p">:</span> <span class="s1">&#39;.onnx&#39;</span><span class="p">,</span>
  290. <span class="s1">&#39;TENSORRT&#39;</span><span class="p">:</span> <span class="s1">&#39;.pkl&#39;</span><span class="p">,</span>
  291. <span class="s1">&#39;OPENVINO&#39;</span><span class="p">:</span> <span class="s1">&#39;.pkl&#39;</span><span class="p">,</span>
  292. <span class="s1">&#39;TORCHSCRIPT&#39;</span><span class="p">:</span> <span class="s1">&#39;.pth&#39;</span><span class="p">,</span>
  293. <span class="s1">&#39;TVM&#39;</span><span class="p">:</span> <span class="s1">&#39;&#39;</span><span class="p">,</span>
  294. <span class="s1">&#39;KERAS&#39;</span><span class="p">:</span> <span class="s1">&#39;.h5&#39;</span><span class="p">,</span>
  295. <span class="s1">&#39;TFLITE&#39;</span><span class="p">:</span> <span class="s1">&#39;.tflite&#39;</span>
  296. <span class="p">}</span>
  297. <span class="k">if</span> <span class="n">framework</span><span class="o">.</span><span class="n">upper</span><span class="p">()</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">frameworks_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
  298. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Unsupported framework: </span><span class="si">{</span><span class="n">framework</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
  299. <span class="k">return</span> <span class="n">frameworks_dict</span><span class="p">[</span><span class="n">framework</span><span class="o">.</span><span class="n">upper</span><span class="p">()]</span></div>
  300. <div class="viewcode-block" id="check_models_have_same_weights"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.check_models_have_same_weights">[docs]</a><span class="k">def</span> <span class="nf">check_models_have_same_weights</span><span class="p">(</span><span class="n">model_1</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">model_2</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  301. <span class="sd">&quot;&quot;&quot;</span>
  302. <span class="sd"> Checks whether two networks have the same weights</span>
  303. <span class="sd"> @param model_1: Net to be checked</span>
  304. <span class="sd"> @param model_2: Net to be checked</span>
  305. <span class="sd"> @return: True iff the two networks have the same weights</span>
  306. <span class="sd"> &quot;&quot;&quot;</span>
  307. <span class="n">model_1</span><span class="p">,</span> <span class="n">model_2</span> <span class="o">=</span> <span class="n">model_1</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">),</span> <span class="n">model_2</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s1">&#39;cpu&#39;</span><span class="p">)</span>
  308. <span class="n">models_differ</span> <span class="o">=</span> <span class="mi">0</span>
  309. <span class="k">for</span> <span class="n">key_item_1</span><span class="p">,</span> <span class="n">key_item_2</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">model_1</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <span class="n">model_2</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span><span class="o">.</span><span class="n">items</span><span class="p">()):</span>
  310. <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">equal</span><span class="p">(</span><span class="n">key_item_1</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">key_item_2</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span>
  311. <span class="k">pass</span>
  312. <span class="k">else</span><span class="p">:</span>
  313. <span class="n">models_differ</span> <span class="o">+=</span> <span class="mi">1</span>
  314. <span class="k">if</span> <span class="p">(</span><span class="n">key_item_1</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">key_item_2</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
  315. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Layer names match but layers have different weights for layers: </span><span class="si">{</span><span class="n">key_item_1</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
  316. <span class="k">if</span> <span class="n">models_differ</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
  317. <span class="k">return</span> <span class="kc">True</span>
  318. <span class="k">else</span><span class="p">:</span>
  319. <span class="k">return</span> <span class="kc">False</span></div>
  320. <div class="viewcode-block" id="recursive_override"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.recursive_override">[docs]</a><span class="k">def</span> <span class="nf">recursive_override</span><span class="p">(</span><span class="n">base</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">extension</span><span class="p">:</span> <span class="nb">dict</span><span class="p">):</span>
  321. <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">extension</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  322. <span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">base</span><span class="p">:</span>
  323. <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Mapping</span><span class="p">):</span>
  324. <span class="n">recursive_override</span><span class="p">(</span><span class="n">base</span><span class="p">[</span><span class="n">k</span><span class="p">],</span> <span class="n">extension</span><span class="p">[</span><span class="n">k</span><span class="p">])</span>
  325. <span class="k">else</span><span class="p">:</span>
  326. <span class="n">base</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">extension</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
  327. <span class="k">else</span><span class="p">:</span>
  328. <span class="n">base</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">extension</span><span class="p">[</span><span class="n">k</span><span class="p">]</span></div>
  329. <div class="viewcode-block" id="download_and_unzip_from_url"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.download_and_unzip_from_url">[docs]</a><span class="k">def</span> <span class="nf">download_and_unzip_from_url</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="nb">dir</span><span class="o">=</span><span class="s1">&#39;.&#39;</span><span class="p">,</span> <span class="n">unzip</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">delete</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
  330. <span class="sd">&quot;&quot;&quot;</span>
  331. <span class="sd"> Downloads a zip file from url to dir, and unzips it.</span>
  332. <span class="sd"> :param url: Url to download the file from.</span>
  333. <span class="sd"> :param dir: Destination directory.</span>
  334. <span class="sd"> :param unzip: Whether to unzip the downloaded file.</span>
  335. <span class="sd"> :param delete: Whether to delete the zip file.</span>
  336. <span class="sd"> used to downlaod VOC.</span>
  337. <span class="sd"> Source:</span>
  338. <span class="sd"> https://github.com/ultralytics/yolov5/blob/master/data/VOC.yaml</span>
  339. <span class="sd"> &quot;&quot;&quot;</span>
  340. <span class="k">def</span> <span class="nf">download_one</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="nb">dir</span><span class="p">):</span>
  341. <span class="c1"># Download 1 file</span>
  342. <span class="n">f</span> <span class="o">=</span> <span class="nb">dir</span> <span class="o">/</span> <span class="n">Path</span><span class="p">(</span><span class="n">url</span><span class="p">)</span><span class="o">.</span><span class="n">name</span> <span class="c1"># filename</span>
  343. <span class="k">if</span> <span class="n">Path</span><span class="p">(</span><span class="n">url</span><span class="p">)</span><span class="o">.</span><span class="n">is_file</span><span class="p">():</span> <span class="c1"># exists in current path</span>
  344. <span class="n">Path</span><span class="p">(</span><span class="n">url</span><span class="p">)</span><span class="o">.</span><span class="n">rename</span><span class="p">(</span><span class="n">f</span><span class="p">)</span> <span class="c1"># move to dir</span>
  345. <span class="k">elif</span> <span class="ow">not</span> <span class="n">f</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
  346. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Downloading </span><span class="si">{</span><span class="n">url</span><span class="si">}</span><span class="s1"> to </span><span class="si">{</span><span class="n">f</span><span class="si">}</span><span class="s1">...&#39;</span><span class="p">)</span>
  347. <span class="n">torch</span><span class="o">.</span><span class="n">hub</span><span class="o">.</span><span class="n">download_url_to_file</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">f</span><span class="p">,</span> <span class="n">progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># torch download</span>
  348. <span class="k">if</span> <span class="n">unzip</span> <span class="ow">and</span> <span class="n">f</span><span class="o">.</span><span class="n">suffix</span> <span class="ow">in</span> <span class="p">(</span><span class="s1">&#39;.zip&#39;</span><span class="p">,</span> <span class="s1">&#39;.gz&#39;</span><span class="p">):</span>
  349. <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Unzipping </span><span class="si">{</span><span class="n">f</span><span class="si">}</span><span class="s1">...&#39;</span><span class="p">)</span>
  350. <span class="k">if</span> <span class="n">f</span><span class="o">.</span><span class="n">suffix</span> <span class="o">==</span> <span class="s1">&#39;.zip&#39;</span><span class="p">:</span>
  351. <span class="n">ZipFile</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="n">path</span><span class="o">=</span><span class="nb">dir</span><span class="p">)</span> <span class="c1"># unzip</span>
  352. <span class="k">elif</span> <span class="n">f</span><span class="o">.</span><span class="n">suffix</span> <span class="o">==</span> <span class="s1">&#39;.gz&#39;</span><span class="p">:</span>
  353. <span class="n">os</span><span class="o">.</span><span class="n">system</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;tar xfz </span><span class="si">{</span><span class="n">f</span><span class="si">}</span><span class="s1"> --directory </span><span class="si">{</span><span class="n">f</span><span class="o">.</span><span class="n">parent</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span> <span class="c1"># unzip</span>
  354. <span class="k">if</span> <span class="n">delete</span><span class="p">:</span>
  355. <span class="n">f</span><span class="o">.</span><span class="n">unlink</span><span class="p">()</span> <span class="c1"># remove zip</span>
  356. <span class="nb">dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="nb">dir</span><span class="p">)</span>
  357. <span class="nb">dir</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># make directory</span>
  358. <span class="k">for</span> <span class="n">u</span> <span class="ow">in</span> <span class="p">[</span><span class="n">url</span><span class="p">]</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">))</span> <span class="k">else</span> <span class="n">url</span><span class="p">:</span>
  359. <span class="n">download_one</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="nb">dir</span><span class="p">)</span></div>
  360. <div class="viewcode-block" id="download_and_untar_from_url"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.download_and_untar_from_url">[docs]</a><span class="k">def</span> <span class="nf">download_and_untar_from_url</span><span class="p">(</span><span class="n">urls</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">],</span> <span class="nb">dir</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Path</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;.&#39;</span><span class="p">):</span>
  361. <span class="sd">&quot;&quot;&quot;</span>
  362. <span class="sd"> Download a file from url and untar.</span>
  363. <span class="sd"> :param urls: Url to download the file from.</span>
  364. <span class="sd"> :param dir: Destination directory.</span>
  365. <span class="sd"> &quot;&quot;&quot;</span>
  366. <span class="nb">dir</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="nb">dir</span><span class="p">)</span>
  367. <span class="nb">dir</span><span class="o">.</span><span class="n">mkdir</span><span class="p">(</span><span class="n">parents</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  368. <span class="k">for</span> <span class="n">url</span> <span class="ow">in</span> <span class="n">urls</span><span class="p">:</span>
  369. <span class="n">url_path</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">url</span><span class="p">)</span>
  370. <span class="n">filepath</span> <span class="o">=</span> <span class="nb">dir</span> <span class="o">/</span> <span class="n">url_path</span><span class="o">.</span><span class="n">name</span>
  371. <span class="k">if</span> <span class="n">url_path</span><span class="o">.</span><span class="n">is_file</span><span class="p">():</span>
  372. <span class="n">url_path</span><span class="o">.</span><span class="n">rename</span><span class="p">(</span><span class="n">filepath</span><span class="p">)</span>
  373. <span class="k">elif</span> <span class="ow">not</span> <span class="n">filepath</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span>
  374. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Downloading </span><span class="si">{</span><span class="n">url</span><span class="si">}</span><span class="s1"> to </span><span class="si">{</span><span class="n">filepath</span><span class="si">}</span><span class="s1">...&#39;</span><span class="p">)</span>
  375. <span class="n">torch</span><span class="o">.</span><span class="n">hub</span><span class="o">.</span><span class="n">download_url_to_file</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">filepath</span><span class="p">),</span> <span class="n">progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  376. <span class="n">modes</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;.tar.gz&quot;</span><span class="p">:</span> <span class="s2">&quot;r:gz&quot;</span><span class="p">,</span> <span class="s2">&quot;.tar&quot;</span><span class="p">:</span> <span class="s2">&quot;r:&quot;</span><span class="p">}</span>
  377. <span class="k">assert</span> <span class="n">filepath</span><span class="o">.</span><span class="n">suffix</span> <span class="ow">in</span> <span class="n">modes</span><span class="o">.</span><span class="n">keys</span><span class="p">(),</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">filepath</span><span class="si">}</span><span class="s2"> has </span><span class="si">{</span><span class="n">filepath</span><span class="o">.</span><span class="n">suffix</span><span class="si">}</span><span class="s2"> suffix which is not supported&quot;</span>
  378. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Extracting to </span><span class="si">{</span><span class="nb">dir</span><span class="si">}</span><span class="s1">...&#39;</span><span class="p">)</span>
  379. <span class="k">with</span> <span class="n">tarfile</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">filepath</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="n">modes</span><span class="p">[</span><span class="n">filepath</span><span class="o">.</span><span class="n">suffix</span><span class="p">])</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
  380. <span class="n">f</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="nb">dir</span><span class="p">)</span>
  381. <span class="n">filepath</span><span class="o">.</span><span class="n">unlink</span><span class="p">()</span></div>
  382. <div class="viewcode-block" id="make_divisible"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.make_divisible">[docs]</a><span class="k">def</span> <span class="nf">make_divisible</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">divisor</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">ceil</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
  383. <span class="sd">&quot;&quot;&quot;</span>
  384. <span class="sd"> Returns x evenly divisible by divisor.</span>
  385. <span class="sd"> If ceil=True it will return the closest larger number to the original x, and ceil=False the closest smaller number.</span>
  386. <span class="sd"> &quot;&quot;&quot;</span>
  387. <span class="k">if</span> <span class="n">ceil</span><span class="p">:</span>
  388. <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">divisor</span><span class="p">)</span> <span class="o">*</span> <span class="n">divisor</span>
  389. <span class="k">else</span><span class="p">:</span>
  390. <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="n">divisor</span><span class="p">)</span> <span class="o">*</span> <span class="n">divisor</span></div>
  391. <div class="viewcode-block" id="check_img_size_divisibility"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.check_img_size_divisibility">[docs]</a><span class="k">def</span> <span class="nf">check_img_size_divisibility</span><span class="p">(</span><span class="n">img_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">stride</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">bool</span><span class="p">,</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]]:</span>
  392. <span class="sd">&quot;&quot;&quot;</span>
  393. <span class="sd"> :param img_size: Int, the size of the image (H or W).</span>
  394. <span class="sd"> :param stride: Int, the number to check if img_size is divisible by.</span>
  395. <span class="sd"> :return: (True, None) if img_size is divisble by stride, (False, Suggestions) if it&#39;s not.</span>
  396. <span class="sd"> Note: Suggestions are the two closest numbers to img_size that *are* divisible by stride.</span>
  397. <span class="sd"> For example if img_size=321, stride=32, it will return (False,(352, 320)).</span>
  398. <span class="sd"> &quot;&quot;&quot;</span>
  399. <span class="n">new_size</span> <span class="o">=</span> <span class="n">make_divisible</span><span class="p">(</span><span class="n">img_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="n">stride</span><span class="p">))</span>
  400. <span class="k">if</span> <span class="n">new_size</span> <span class="o">!=</span> <span class="n">img_size</span><span class="p">:</span>
  401. <span class="k">return</span> <span class="kc">False</span><span class="p">,</span> <span class="p">(</span><span class="n">new_size</span><span class="p">,</span> <span class="n">make_divisible</span><span class="p">(</span><span class="n">img_size</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="n">stride</span><span class="p">),</span> <span class="n">ceil</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
  402. <span class="k">else</span><span class="p">:</span>
  403. <span class="k">return</span> <span class="kc">True</span><span class="p">,</span> <span class="kc">None</span></div>
  404. <div class="viewcode-block" id="get_orientation_key"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_orientation_key">[docs]</a><span class="nd">@lru_cache</span><span class="p">(</span><span class="kc">None</span><span class="p">)</span>
  405. <span class="k">def</span> <span class="nf">get_orientation_key</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
  406. <span class="sd">&quot;&quot;&quot;Get the orientation key according to PIL, which is useful to get the image size for instance</span>
  407. <span class="sd"> :return: Orientation key according to PIL&quot;&quot;&quot;</span>
  408. <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">ExifTags</span><span class="o">.</span><span class="n">TAGS</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
  409. <span class="k">if</span> <span class="n">value</span> <span class="o">==</span> <span class="s1">&#39;Orientation&#39;</span><span class="p">:</span>
  410. <span class="k">return</span> <span class="n">key</span></div>
  411. <div class="viewcode-block" id="exif_size"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.exif_size">[docs]</a><span class="k">def</span> <span class="nf">exif_size</span><span class="p">(</span><span class="n">image</span><span class="p">:</span> <span class="n">Image</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]:</span>
  412. <span class="sd">&quot;&quot;&quot;Get the size of image.</span>
  413. <span class="sd"> :param image: The image to get size from</span>
  414. <span class="sd"> :return: (width, height)</span>
  415. <span class="sd"> &quot;&quot;&quot;</span>
  416. <span class="n">orientation_key</span> <span class="o">=</span> <span class="n">get_orientation_key</span><span class="p">()</span>
  417. <span class="n">image_size</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">size</span>
  418. <span class="k">try</span><span class="p">:</span>
  419. <span class="n">exif_data</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">_getexif</span><span class="p">()</span>
  420. <span class="k">if</span> <span class="n">exif_data</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  421. <span class="n">rotation</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="n">exif_data</span><span class="o">.</span><span class="n">items</span><span class="p">())[</span><span class="n">orientation_key</span><span class="p">]</span>
  422. <span class="c1"># ROTATION 270</span>
  423. <span class="k">if</span> <span class="n">rotation</span> <span class="o">==</span> <span class="mi">6</span><span class="p">:</span>
  424. <span class="n">image_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">image_size</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
  425. <span class="c1"># ROTATION 90</span>
  426. <span class="k">elif</span> <span class="n">rotation</span> <span class="o">==</span> <span class="mi">8</span><span class="p">:</span>
  427. <span class="n">image_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">image_size</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
  428. <span class="k">except</span> <span class="ne">Exception</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
  429. <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Caught Exception trying to rotate: &#39;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">ex</span><span class="p">))</span>
  430. <span class="n">height</span><span class="p">,</span> <span class="n">width</span> <span class="o">=</span> <span class="n">image_size</span>
  431. <span class="k">return</span> <span class="n">width</span><span class="p">,</span> <span class="n">height</span></div>
  432. <div class="viewcode-block" id="get_image_size_from_path"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.get_image_size_from_path">[docs]</a><span class="k">def</span> <span class="nf">get_image_size_from_path</span><span class="p">(</span><span class="n">img_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]:</span>
  433. <span class="sd">&quot;&quot;&quot;Get the image size of an image at a specific path&quot;&quot;&quot;</span>
  434. <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">img_path</span><span class="p">,</span> <span class="s1">&#39;rb&#39;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
  435. <span class="k">return</span> <span class="n">exif_size</span><span class="p">(</span><span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">f</span><span class="p">))</span></div>
  436. </pre></div>
  437. </div>
  438. </div>
  439. <footer>
  440. <hr/>
  441. <div role="contentinfo">
  442. <p>&#169; Copyright 2021, SuperGradients team.</p>
  443. </div>
  444. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  445. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  446. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  447. </footer>
  448. </div>
  449. </div>
  450. </section>
  451. </div>
  452. <script>
  453. jQuery(function () {
  454. SphinxRtdTheme.Navigation.enable(true);
  455. });
  456. </script>
  457. </body>
  458. </html>
Discard
Tip!

Press p or to see the previous file or, n or to see the next file