checkpoint_utils.html 53 KB

  1. <!DOCTYPE html>
  2. <html class="writer-html5" lang="en" >
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  6. <title>super_gradients.training.utils.checkpoint_utils &mdash; SuperGradients 3.0.3 documentation</title>
  7. <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
  8. <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
  9. <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
  10. <link rel="stylesheet" href="../../../../_static/custom.css" type="text/css" />
  11. <!--[if lt IE 9]>
  12. <script src="../../../../_static/js/html5shiv.min.js"></script>
  13. <![endif]-->
  14. <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
  15. <script src="../../../../_static/jquery.js"></script>
  16. <script src="../../../../_static/underscore.js"></script>
  17. <script src="../../../../_static/_sphinx_javascript_frameworks_compat.js"></script>
  18. <script src="../../../../_static/doctools.js"></script>
  19. <script src="../../../../_static/sphinx_highlight.js"></script>
  20. <script src="../../../../_static/js/theme.js"></script>
  21. <link rel="index" title="Index" href="../../../../genindex.html" />
  22. <link rel="search" title="Search" href="../../../../search.html" />
  23. </head>
  24. <body class="wy-body-for-nav">
  25. <div class="wy-grid-for-nav">
  26. <nav data-toggle="wy-nav-shift" class="wy-nav-side">
  27. <div class="wy-side-scroll">
  28. <div class="wy-side-nav-search" >
  29. <a href="../../../../index.html" class="icon icon-home"> SuperGradients
  30. </a>
  31. <div role="search">
  32. <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
  33. <input type="text" name="q" placeholder="Search docs" />
  34. <input type="hidden" name="check_keywords" value="yes" />
  35. <input type="hidden" name="area" value="default" />
  36. </form>
  37. </div>
  38. </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
  39. <p class="caption" role="heading"><span class="caption-text">Welcome To SuperGradients</span></p>
  40. <ul>
  41. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Version 3 is out! Notebooks have been updated!</a></li>
  42. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#build-with-supergradients">Build with SuperGradients</a></li>
  43. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#quick-installation">Quick Installation</a></li>
  44. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#what-s-new">What’s New</a></li>
  45. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#coming-soon">Coming soon</a></li>
  46. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#table-of-content">Table of Content</a></li>
  47. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#getting-started">Getting Started</a></li>
  48. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#advanced-features">Advanced Features</a></li>
  49. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#installation-methods">Installation Methods</a></li>
  50. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#implemented-model-architectures">Implemented Model Architectures</a></li>
  51. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#documentation">Documentation</a></li>
  52. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#contributing">Contributing</a></li>
  53. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#citation">Citation</a></li>
  54. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#community">Community</a></li>
  55. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#license">License</a></li>
  56. <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#deci-platform">Deci Platform</a></li>
  57. </ul>
  58. <p class="caption" role="heading"><span class="caption-text">Technical Documentation</span></p>
  59. <ul>
  60. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
  61. <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
  62. </ul>
  63. </div>
  64. </div>
  65. </nav>
  66. <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
  67. <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
  68. <a href="../../../../index.html">SuperGradients</a>
  69. </nav>
  70. <div class="wy-nav-content">
  71. <div class="rst-content">
  72. <div role="navigation" aria-label="Page navigation">
  73. <ul class="wy-breadcrumbs">
  74. <li><a href="../../../../index.html" class="icon icon-home"></a> &raquo;</li>
  75. <li><a href="../../../index.html">Module code</a> &raquo;</li>
  76. <li>super_gradients.training.utils.checkpoint_utils</li>
  77. <li class="wy-breadcrumbs-aside">
  78. </li>
  79. </ul>
  80. <hr/>
  81. </div>
  82. <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
  83. <div itemprop="articleBody">
  84. <h1>Source code for super_gradients.training.utils.checkpoint_utils</h1><div class="highlight"><pre>
  85. <span></span><span class="kn">import</span> <span class="nn">os</span>
  86. <span class="kn">import</span> <span class="nn">tempfile</span>
  87. <span class="kn">import</span> <span class="nn">pkg_resources</span>
  88. <span class="kn">import</span> <span class="nn">torch</span>
  89. <span class="kn">from</span> <span class="nn">super_gradients.common.abstractions.abstract_logger</span> <span class="kn">import</span> <span class="n">get_logger</span>
  90. <span class="kn">from</span> <span class="nn">super_gradients.common</span> <span class="kn">import</span> <span class="n">explicit_params_validation</span><span class="p">,</span> <span class="n">ADNNModelRepositoryDataInterfaces</span>
  91. <span class="kn">from</span> <span class="nn">super_gradients.training.pretrained_models</span> <span class="kn">import</span> <span class="n">MODEL_URLS</span>
  92. <span class="kn">from</span> <span class="nn">super_gradients.common.environment</span> <span class="kn">import</span> <span class="n">environment_config</span>
  93. <span class="k">try</span><span class="p">:</span>
  94. <span class="kn">from</span> <span class="nn">torch.hub</span> <span class="kn">import</span> <span class="n">download_url_to_file</span><span class="p">,</span> <span class="n">load_state_dict_from_url</span>
  95. <span class="k">except</span> <span class="p">(</span><span class="ne">ModuleNotFoundError</span><span class="p">,</span> <span class="ne">ImportError</span><span class="p">,</span> <span class="ne">NameError</span><span class="p">):</span>
  96. <span class="kn">from</span> <span class="nn">torch.hub</span> <span class="kn">import</span> <span class="n">_download_url_to_file</span> <span class="k">as</span> <span class="n">download_url_to_file</span>
  97. <span class="n">logger</span> <span class="o">=</span> <span class="n">get_logger</span><span class="p">(</span><span class="vm">__name__</span><span class="p">)</span>
  98. <span class="k">def</span> <span class="nf">get_checkpoints_dir_path</span><span class="p">(</span><span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">ckpt_root_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  99. <span class="sd">&quot;&quot;&quot;Return the path of the checkpoint directory of a given experiment (the directory itself is not created here).</span>
  100. <span class="sd"> :param experiment_name: Name of the experiment.</span>
  101. <span class="sd"> :param ckpt_root_dir: Local root directory path where all experiment logging directories will</span>
  102. <span class="sd"> reside. When none is given, it is assumed that pkg_resources.resource_filename(&#39;checkpoints&#39;, &quot;&quot;)</span>
  103. <span class="sd"> exists and will be used.</span>
  104. <span class="sd"> :return: checkpoints_dir_path</span>
  105. <span class="sd"> &quot;&quot;&quot;</span>
  106. <span class="k">if</span> <span class="n">ckpt_root_dir</span><span class="p">:</span>
  107. <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">ckpt_root_dir</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">)</span>
  108. <span class="k">elif</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">environment_config</span><span class="o">.</span><span class="n">PKG_CHECKPOINTS_DIR</span><span class="p">):</span>
  109. <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">environment_config</span><span class="o">.</span><span class="n">PKG_CHECKPOINTS_DIR</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">)</span>
  110. <span class="k">else</span><span class="p">:</span>
  111. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Illegal checkpoints directory: pass ckpt_root_dir that exists, or add &#39;checkpoints&#39; to resources.&quot;</span><span class="p">)</span>
  112. <span class="k">def</span> <span class="nf">get_ckpt_local_path</span><span class="p">(</span><span class="n">source_ckpt_folder_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">experiment_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">ckpt_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">external_checkpoint_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  113. <span class="sd">&quot;&quot;&quot;</span>
  114. <span class="sd"> Gets the local path to the checkpoint file, which will be:</span>
  115. <span class="sd"> - By default: YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name.</span>
  116. <span class="sd"> - if the checkpoint file is remotely located:</span>
  117. <span class="sd"> when overwrite_local_checkpoint=False then it will be saved in a temporary path which will be returned,</span>
  118. <span class="sd"> otherwise it will be downloaded to YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name and overwrite</span>
  119. <span class="sd"> YOUR_REPO_ROOT/super_gradients/checkpoints/experiment_name/ckpt_name if such file exists.</span>
  120. <span class="sd"> - external_checkpoint_path when external_checkpoint_path != None</span>
  121. <span class="sd"> @param source_ckpt_folder_name: The folder where the checkpoint is saved. When set to None, the experiment_name is used.</span>
  122. <span class="sd"> @param experiment_name: experiment name attr in trainer</span>
  123. <span class="sd"> @param ckpt_name: checkpoint filename</span>
  124. <span class="sd"> @param external_checkpoint_path: full path to checkpoint file (that might be located outside of super_gradients/checkpoints directory)</span>
  125. <span class="sd"> @return:</span>
  126. <span class="sd"> &quot;&quot;&quot;</span>
  127. <span class="k">if</span> <span class="n">external_checkpoint_path</span><span class="p">:</span>
  128. <span class="k">return</span> <span class="n">external_checkpoint_path</span>
  129. <span class="k">else</span><span class="p">:</span>
  130. <span class="n">checkpoints_folder_name</span> <span class="o">=</span> <span class="n">source_ckpt_folder_name</span> <span class="ow">or</span> <span class="n">experiment_name</span>
  131. <span class="n">checkpoints_dir_path</span> <span class="o">=</span> <span class="n">get_checkpoints_dir_path</span><span class="p">(</span><span class="n">checkpoints_folder_name</span><span class="p">)</span>
  132. <span class="k">return</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">checkpoints_dir_path</span><span class="p">,</span> <span class="n">ckpt_name</span><span class="p">)</span>
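# Illustrative sketch (not part of the module): the lookup order described in the
# docstring above can be written without any SuperGradients import. The helper name
# resolve_ckpt_path and its ckpt_root_dir argument are hypothetical; the real function
# takes its root from get_checkpoints_dir_path instead.

```python
import os


def resolve_ckpt_path(ckpt_root_dir, experiment_name, ckpt_name,
                      external_checkpoint_path=None, source_ckpt_folder_name=None):
    """Hypothetical stand-in mirroring the precedence of get_ckpt_local_path."""
    # An explicit external path takes precedence over everything else.
    if external_checkpoint_path:
        return external_checkpoint_path
    # Otherwise fall back to the source folder name, then to the experiment name.
    folder = source_ckpt_folder_name or experiment_name
    return os.path.join(ckpt_root_dir, folder, ckpt_name)


default_path = resolve_ckpt_path("/tmp/checkpoints", "my_experiment", "ckpt_best.pth")
external_path = resolve_ckpt_path(
    "/tmp/checkpoints", "my_experiment", "ckpt_best.pth",
    external_checkpoint_path="/data/model.pth",
)
print(default_path)
print(external_path)
```

# The external path is returned untouched, so it may point outside the checkpoints directory.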
  133. <span class="k">def</span> <span class="nf">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">state_dict</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">strict</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  134. <span class="sd">&quot;&quot;&quot;</span>
  135. <span class="sd"> Adaptively loads state_dict to net, by adapting the state_dict to net&#39;s layer names first.</span>
  136. <span class="sd"> @param net: (nn.Module) to load state_dict to</span>
  137. <span class="sd"> @param state_dict: (dict) Checkpoint state_dict</span>
  138. <span class="sd"> @param strict: (str) key matching strictness</span>
  139. <span class="sd"> @return:</span>
  140. <span class="sd"> &quot;&quot;&quot;</span>
  141. <span class="k">try</span><span class="p">:</span>
  142. <span class="n">net</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="k">if</span> <span class="s2">&quot;net&quot;</span> <span class="ow">in</span> <span class="n">state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span> <span class="k">else</span> <span class="n">state_dict</span><span class="p">,</span> <span class="n">strict</span><span class="o">=</span><span class="n">strict</span><span class="p">)</span>
  143. <span class="k">except</span> <span class="p">(</span><span class="ne">RuntimeError</span><span class="p">,</span> <span class="ne">ValueError</span><span class="p">,</span> <span class="ne">KeyError</span><span class="p">)</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span>
  144. <span class="k">if</span> <span class="n">strict</span> <span class="o">==</span> <span class="s2">&quot;no_key_matching&quot;</span><span class="p">:</span>
  145. <span class="n">adapted_state_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">state_dict</span><span class="p">)</span>
  146. <span class="n">net</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">adapted_state_dict</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  147. <span class="k">else</span><span class="p">:</span>
  148. <span class="n">raise_informative_runtime_error</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">state_dict</span><span class="p">,</span> <span class="n">ex</span><span class="p">)</span>
  149. <span class="nd">@explicit_params_validation</span><span class="p">(</span><span class="n">validation_type</span><span class="o">=</span><span class="s2">&quot;None&quot;</span><span class="p">)</span>
  150. <span class="k">def</span> <span class="nf">copy_ckpt_to_local_folder</span><span class="p">(</span>
  151. <span class="n">local_ckpt_destination_dir</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  152. <span class="n">ckpt_filename</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
  153. <span class="n">remote_ckpt_source_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
  154. <span class="n">path_src</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;local&quot;</span><span class="p">,</span>
  155. <span class="n">overwrite_local_ckpt</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
  156. <span class="n">load_weights_only</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
  157. <span class="p">):</span>
  158. <span class="sd">&quot;&quot;&quot;</span>
  159. <span class="sd"> Copy the checkpoint from any supported source to a local destination path</span>
  160. <span class="sd"> :param local_ckpt_destination_dir: destination where the checkpoint will be saved to</span>
  161. <span class="sd"> :param ckpt_filename: ckpt_best.pth Or ckpt_latest.pth</span>
  162. <span class="sd"> :param remote_ckpt_source_dir: Name of the source checkpoint to be loaded (S3 model name or full URL)</span>
  163. <span class="sd"> :param path_src: Source type: a string starting with &quot;s3&quot;, or &quot;url&quot;</span>
  164. <span class="sd"> :param overwrite_local_ckpt: If True, the checkpoint is saved in the destination dir; if False, in a temp folder</span>
  165. <span class="sd"> :return: Path to checkpoint</span>
  166. <span class="sd"> &quot;&quot;&quot;</span>
  167. <span class="n">ckpt_file_full_local_path</span> <span class="o">=</span> <span class="kc">None</span>
  168. <span class="c1"># IF NOT DEFINED - IT IS SET TO THE TARGET&#39;s FOLDER NAME</span>
  169. <span class="n">remote_ckpt_source_dir</span> <span class="o">=</span> <span class="n">local_ckpt_destination_dir</span> <span class="k">if</span> <span class="n">remote_ckpt_source_dir</span> <span class="ow">is</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">remote_ckpt_source_dir</span>
  170. <span class="k">if</span> <span class="ow">not</span> <span class="n">overwrite_local_ckpt</span><span class="p">:</span>
  171. <span class="c1"># CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO</span>
  172. <span class="n">download_ckpt_destination_dir</span> <span class="o">=</span> <span class="n">tempfile</span><span class="o">.</span><span class="n">gettempdir</span><span class="p">()</span>
  173. <span class="nb">print</span><span class="p">(</span>
  174. <span class="s2">&quot;PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False &quot;</span>
  175. <span class="s2">&quot;-&gt; IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART&quot;</span>
  176. <span class="p">)</span>
  177. <span class="k">else</span><span class="p">:</span>
  178. <span class="c1"># SAVE THE CHECKPOINT TO MODEL&#39;s FOLDER</span>
  179. <span class="n">download_ckpt_destination_dir</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s2">&quot;checkpoints&quot;</span><span class="p">,</span> <span class="n">local_ckpt_destination_dir</span><span class="p">)</span>
  180. <span class="k">if</span> <span class="n">path_src</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">&quot;s3&quot;</span><span class="p">):</span>
  181. <span class="n">model_checkpoints_data_interface</span> <span class="o">=</span> <span class="n">ADNNModelRepositoryDataInterfaces</span><span class="p">(</span><span class="n">data_connection_location</span><span class="o">=</span><span class="n">path_src</span><span class="p">)</span>
  182. <span class="c1"># DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER</span>
  183. <span class="n">ckpt_file_full_local_path</span> <span class="o">=</span> <span class="n">model_checkpoints_data_interface</span><span class="o">.</span><span class="n">load_remote_checkpoints_file</span><span class="p">(</span>
  184. <span class="n">ckpt_source_remote_dir</span><span class="o">=</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span>
  185. <span class="n">ckpt_destination_local_dir</span><span class="o">=</span><span class="n">download_ckpt_destination_dir</span><span class="p">,</span>
  186. <span class="n">ckpt_file_name</span><span class="o">=</span><span class="n">ckpt_filename</span><span class="p">,</span>
  187. <span class="n">overwrite_local_checkpoints_file</span><span class="o">=</span><span class="n">overwrite_local_ckpt</span><span class="p">,</span>
  188. <span class="p">)</span>
  189. <span class="k">if</span> <span class="ow">not</span> <span class="n">load_weights_only</span><span class="p">:</span>
  190. <span class="c1"># COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODELs CKPT</span>
  191. <span class="n">model_checkpoints_data_interface</span><span class="o">.</span><span class="n">load_all_remote_log_files</span><span class="p">(</span>
  192. <span class="n">model_name</span><span class="o">=</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span> <span class="n">model_checkpoint_local_dir</span><span class="o">=</span><span class="n">download_ckpt_destination_dir</span>
  193. <span class="p">)</span>
  194. <span class="k">if</span> <span class="n">path_src</span> <span class="o">==</span> <span class="s2">&quot;url&quot;</span><span class="p">:</span>
  195. <span class="n">ckpt_file_full_local_path</span> <span class="o">=</span> <span class="n">download_ckpt_destination_dir</span> <span class="o">+</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">sep</span> <span class="o">+</span> <span class="n">ckpt_filename</span>
  196. <span class="c1"># DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER</span>
  197. <span class="n">download_url_to_file</span><span class="p">(</span><span class="n">remote_ckpt_source_dir</span><span class="p">,</span> <span class="n">ckpt_file_full_local_path</span><span class="p">,</span> <span class="n">progress</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  198. <span class="k">return</span> <span class="n">ckpt_file_full_local_path</span>
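# Illustrative sketch (not part of the module): how the overwrite flag selects the
# download directory in the function above. The helper names choose_download_dir and
# local_ckpt_path are hypothetical, and a plain destination directory is assumed here
# in place of the pkg_resources lookup used by the real code.

```python
import os
import tempfile


def choose_download_dir(destination_dir, overwrite_local_ckpt):
    """Hypothetical helper: pick where a remote checkpoint would be written."""
    if not overwrite_local_ckpt:
        # Without overwrite, the download is redirected to the system temp folder.
        return tempfile.gettempdir()
    # With overwrite, the checkpoint is written into the destination folder.
    return destination_dir


def local_ckpt_path(destination_dir, ckpt_filename, overwrite_local_ckpt=False):
    target_dir = choose_download_dir(destination_dir, overwrite_local_ckpt)
    return os.path.join(target_dir, ckpt_filename)


temp_target = local_ckpt_path("/tmp/experiments/run_1", "ckpt_best.pth")
kept_target = local_ckpt_path("/tmp/experiments/run_1", "ckpt_best.pth", overwrite_local_ckpt=True)
print(temp_target)
print(kept_target)
```

# Files placed in the temp folder are not guaranteed to survive a machine restart.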
  199. <span class="k">def</span> <span class="nf">read_ckpt_state_dict</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cpu&quot;</span><span class="p">):</span>
  200. <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">):</span>
  201. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Incorrect Checkpoint path&quot;</span><span class="p">)</span>
  202. <span class="k">if</span> <span class="n">device</span> <span class="o">==</span> <span class="s2">&quot;cuda&quot;</span><span class="p">:</span>
  203. <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">)</span>
  204. <span class="k">else</span><span class="p">:</span>
  205. <span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="k">lambda</span> <span class="n">storage</span><span class="p">,</span> <span class="n">loc</span><span class="p">:</span> <span class="n">storage</span><span class="p">)</span>
  206. <span class="k">return</span> <span class="n">state_dict</span>
  207. <div class="viewcode-block" id="adapt_state_dict_to_fit_model_layer_names"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.adapt_state_dict_to_fit_model_layer_names">[docs]</a><span class="k">def</span> <span class="nf">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">model_state_dict</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">source_ckpt</span><span class="p">:</span> <span class="nb">dict</span><span class="p">,</span> <span class="n">exclude</span><span class="p">:</span> <span class="nb">list</span> <span class="o">=</span> <span class="p">[],</span> <span class="n">solver</span><span class="p">:</span> <span class="n">callable</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span>
  208. <span class="sd">&quot;&quot;&quot;</span>
  209. <span class="sd"> Given a model state dict and source checkpoints, the method tries to rename the keys of the ckpt to fit</span>
  210. <span class="sd"> the model_state_dict in order to properly load the weights into the model. Raises ValueError on a shape mismatch.</span>
  211. <span class="sd"> :param model_state_dict: the model state_dict</span>
  212. <span class="sd"> :param source_ckpt: checkpoint dict</span>
  213. <span class="sd"> :param exclude: optional list of excluded layers</span>
  214. <span class="sd"> :param solver: callable with signature (ckpt_key, ckpt_val, model_key, model_val)</span>
  215. <span class="sd"> that returns a desired weight for ckpt_val.</span>
  216. <span class="sd"> :return: renamed checkpoint dict (if possible)</span>
  217. <span class="sd"> &quot;&quot;&quot;</span>
  218. <span class="k">if</span> <span class="s2">&quot;net&quot;</span> <span class="ow">in</span> <span class="n">source_ckpt</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
  219. <span class="n">source_ckpt</span> <span class="o">=</span> <span class="n">source_ckpt</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span>
  220. <span class="n">model_state_dict_excluded</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">model_state_dict</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">any</span><span class="p">(</span><span class="n">x</span> <span class="ow">in</span> <span class="n">k</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">exclude</span><span class="p">)}</span>
  221. <span class="n">new_ckpt_dict</span> <span class="o">=</span> <span class="p">{}</span>
  222. <span class="k">for</span> <span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">),</span> <span class="p">(</span><span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">source_ckpt</span><span class="o">.</span><span class="n">items</span><span class="p">(),</span> <span class="n">model_state_dict_excluded</span><span class="o">.</span><span class="n">items</span><span class="p">()):</span>
  223. <span class="k">if</span> <span class="n">solver</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  224. <span class="n">ckpt_val</span> <span class="o">=</span> <span class="n">solver</span><span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">,</span> <span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">)</span>
  225. <span class="k">if</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">model_val</span><span class="o">.</span><span class="n">shape</span><span class="p">:</span>
  226. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;ckpt layer </span><span class="si">{</span><span class="n">ckpt_key</span><span class="si">}</span><span class="s2"> with shape </span><span class="si">{</span><span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> does not match </span><span class="si">{</span><span class="n">model_key</span><span class="si">}</span><span class="s2">&quot;</span> <span class="sa">f</span><span class="s2">&quot; with shape </span><span class="si">{</span><span class="n">model_val</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> in the model&quot;</span><span class="p">)</span>
  227. <span class="n">new_ckpt_dict</span><span class="p">[</span><span class="n">model_key</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span>
  228. <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;net&quot;</span><span class="p">:</span> <span class="n">new_ckpt_dict</span><span class="p">}</span></div>
  229. <div class="viewcode-block" id="raise_informative_runtime_error"><a class="viewcode-back" href="../../../../super_gradients.training.html#super_gradients.training.utils.raise_informative_runtime_error">[docs]</a><span class="k">def</span> <span class="nf">raise_informative_runtime_error</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">exception_msg</span><span class="p">):</span>
  230. <span class="sd">&quot;&quot;&quot;</span>
  231. <span class="sd"> Given a model state dict and source checkpoints, the method calls &quot;adapt_state_dict_to_fit_model_layer_names&quot;</span>
  232. <span class="sd"> and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible</span>
  233. <span class="sd"> &quot;&quot;&quot;</span>
  234. <span class="k">try</span><span class="p">:</span>
  235. <span class="n">new_ckpt_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">)</span>
  236. <span class="n">temp_file</span> <span class="o">=</span> <span class="n">tempfile</span><span class="o">.</span><span class="n">NamedTemporaryFile</span><span class="p">()</span><span class="o">.</span><span class="n">name</span> <span class="o">+</span> <span class="s2">&quot;.pt&quot;</span>
  237. <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">new_ckpt_dict</span><span class="p">,</span> <span class="n">temp_file</span><span class="p">)</span>
  238. <span class="n">exception_msg</span> <span class="o">=</span> <span class="p">(</span>
  239. <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="se">\n</span><span class="si">{</span><span class="nb">str</span><span class="p">(</span><span class="n">exception_msg</span><span class="p">)</span><span class="si">}</span><span class="s2"> </span><span class="se">\n</span><span class="s2">convert ckpt via the utils.adapt_state_dict_to_fit_&quot;</span>
  240. <span class="sa">f</span><span class="s2">&quot;model_layer_names method</span><span class="se">\n</span><span class="s2">a converted checkpoint file was saved in the path </span><span class="si">{</span><span class="n">temp_file</span><span class="si">}</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="s2">&quot;</span>
  241. <span class="p">)</span>
  242. <span class="k">except</span> <span class="ne">ValueError</span> <span class="k">as</span> <span class="n">ex</span><span class="p">:</span> <span class="c1"># IN CASE adapt_state_dict_to_fit_model_layer_names WAS UNSUCCESSFUL</span>
  243. <span class="n">exception_msg</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="s2"> </span><span class="se">\n</span><span class="s2">The checkpoint and model shapes do not fit, e.g.: </span><span class="si">{</span><span class="n">ex</span><span class="si">}</span><span class="se">\n</span><span class="si">{</span><span class="s1">&#39;=&#39;</span> <span class="o">*</span> <span class="mi">200</span><span class="si">}</span><span class="s2">&quot;</span>
  244. <span class="k">finally</span><span class="p">:</span>
  245. <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="n">exception_msg</span><span class="p">)</span></div>
  246. <span class="k">def</span> <span class="nf">load_checkpoint_to_model</span><span class="p">(</span>
  247. <span class="n">ckpt_local_path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">load_backbone</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">net</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">strict</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">load_weights_only</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">load_ema_as_net</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span>
  248. <span class="p">):</span>
  249. <span class="sd">&quot;&quot;&quot;</span>
  250. <span class="sd"> Loads the state dict in ckpt_local_path to net and returns the checkpoint&#39;s state dict.</span>
  251. <span class="sd"> @param load_ema_as_net: Will load the EMA inside the checkpoint file to the network when set</span>
  252. <span class="sd"> @param ckpt_local_path: local path to the checkpoint file</span>
  253. <span class="sd"> @param load_backbone: whether to load the checkpoint as a backbone</span>
  254. <span class="sd"> @param net: network to load the checkpoint to</span>
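  255. <span class="sd"> @param strict: strictness mode passed on to adaptive_load_state_dict</span>
  256. <span class="sd"> @param load_weights_only: when set (or when load_backbone is set), all checkpoint data other than the weights (&quot;net&quot;) is discarded</span>
  257. <span class="sd"> @return: the checkpoint dict that was read (only its &quot;net&quot; entry when load_weights_only or load_backbone is set)</span>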
  258. <span class="sd"> &quot;&quot;&quot;</span>
  259. <span class="k">if</span> <span class="n">ckpt_local_path</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">or</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">ckpt_local_path</span><span class="p">):</span>
  260. <span class="n">error_msg</span> <span class="o">=</span> <span class="s2">&quot;Error - loading Model Checkpoint: Path </span><span class="si">{}</span><span class="s2"> does not exist&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">ckpt_local_path</span><span class="p">)</span>
  261. <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="n">error_msg</span><span class="p">)</span>
  262. <span class="k">if</span> <span class="n">load_backbone</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="s2">&quot;backbone&quot;</span><span class="p">):</span>
  263. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;No backbone attribute in net - Can&#39;t load backbone weights&quot;</span><span class="p">)</span>
  264. <span class="c1"># LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT</span>
  265. <span class="n">checkpoint</span> <span class="o">=</span> <span class="n">read_ckpt_state_dict</span><span class="p">(</span><span class="n">ckpt_path</span><span class="o">=</span><span class="n">ckpt_local_path</span><span class="p">)</span>
  266. <span class="k">if</span> <span class="n">load_ema_as_net</span><span class="p">:</span>
  267. <span class="k">if</span> <span class="s2">&quot;ema_net&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">checkpoint</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
  268. <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Can&#39;t load ema network- no EMA network stored in checkpoint file&quot;</span><span class="p">)</span>
  269. <span class="k">else</span><span class="p">:</span>
  270. <span class="n">checkpoint</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">checkpoint</span><span class="p">[</span><span class="s2">&quot;ema_net&quot;</span><span class="p">]</span>
  271. <span class="c1"># LOAD THE CHECKPOINTS WEIGHTS TO THE MODEL</span>
  272. <span class="k">if</span> <span class="n">load_backbone</span><span class="p">:</span>
  273. <span class="n">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="o">.</span><span class="n">backbone</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">strict</span><span class="p">)</span>
  274. <span class="k">else</span><span class="p">:</span>
  275. <span class="n">adaptive_load_state_dict</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">checkpoint</span><span class="p">,</span> <span class="n">strict</span><span class="p">)</span>
  276. <span class="n">message_suffix</span> <span class="o">=</span> <span class="s2">&quot; checkpoint.&quot;</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">load_ema_as_net</span> <span class="k">else</span> <span class="s2">&quot; EMA checkpoint.&quot;</span>
  277. <span class="n">message_model</span> <span class="o">=</span> <span class="s2">&quot;model&quot;</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">load_backbone</span> <span class="k">else</span> <span class="s2">&quot;model&#39;s backbone&quot;</span>
  278. <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="s2">&quot;Successfully loaded &quot;</span> <span class="o">+</span> <span class="n">message_model</span> <span class="o">+</span> <span class="s2">&quot; weights from &quot;</span> <span class="o">+</span> <span class="n">ckpt_local_path</span> <span class="o">+</span> <span class="n">message_suffix</span><span class="p">)</span>
  279. <span class="k">if</span> <span class="n">load_weights_only</span> <span class="ow">or</span> <span class="n">load_backbone</span><span class="p">:</span>
  280. <span class="c1"># DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS</span>
  281. <span class="p">[</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="n">checkpoint</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="k">if</span> <span class="n">key</span> <span class="o">!=</span> <span class="s2">&quot;net&quot;</span><span class="p">]</span>
  282. <span class="k">return</span> <span class="n">checkpoint</span>
  283. <span class="k">class</span> <span class="nc">MissingPretrainedWeightsException</span><span class="p">(</span><span class="ne">Exception</span><span class="p">):</span>
  284. <span class="sd">&quot;&quot;&quot;Exception raised by unsupported pretrained model.</span>
  285. <span class="sd"> Attributes:</span>
  286. <span class="sd"> message -- explanation of the error</span>
  287. <span class="sd"> &quot;&quot;&quot;</span>
  288. <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">desc</span><span class="p">):</span>
  289. <span class="bp">self</span><span class="o">.</span><span class="n">message</span> <span class="o">=</span> <span class="s2">&quot;Missing pretrained weights: &quot;</span> <span class="o">+</span> <span class="n">desc</span>
  290. <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">message</span><span class="p">)</span>
  291. <span class="k">def</span> <span class="nf">_yolox_ckpt_solver</span><span class="p">(</span><span class="n">ckpt_key</span><span class="p">,</span> <span class="n">ckpt_val</span><span class="p">,</span> <span class="n">model_key</span><span class="p">,</span> <span class="n">model_val</span><span class="p">):</span>
  292. <span class="sd">&quot;&quot;&quot;</span>
  293. <span class="sd"> Helper method for reshaping old pretrained checkpoint&#39;s focus weights to 6x6 conv weights.</span>
  294. <span class="sd"> &quot;&quot;&quot;</span>
  295. <span class="k">if</span> <span class="p">(</span>
  296. <span class="n">ckpt_val</span><span class="o">.</span><span class="n">shape</span> <span class="o">!=</span> <span class="n">model_val</span><span class="o">.</span><span class="n">shape</span>
  297. <span class="ow">and</span> <span class="n">ckpt_key</span> <span class="o">==</span> <span class="s2">&quot;module._backbone._modules_list.0.conv.conv.weight&quot;</span>
  298. <span class="ow">and</span> <span class="n">model_key</span> <span class="o">==</span> <span class="s2">&quot;_backbone._modules_list.0.conv.weight&quot;</span>
  299. <span class="p">):</span>
  300. <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">3</span><span class="p">]</span>
  301. <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">3</span><span class="p">:</span><span class="mi">6</span><span class="p">]</span>
  302. <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">6</span><span class="p">:</span><span class="mi">9</span><span class="p">]</span>
  303. <span class="n">model_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="n">ckpt_val</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="mi">9</span><span class="p">:</span><span class="mi">12</span><span class="p">]</span>
  304. <span class="n">replacement</span> <span class="o">=</span> <span class="n">model_val</span>
  305. <span class="k">else</span><span class="p">:</span>
  306. <span class="n">replacement</span> <span class="o">=</span> <span class="n">ckpt_val</span>
  307. <span class="k">return</span> <span class="n">replacement</span>
  308. <span class="k">def</span> <span class="nf">load_pretrained_weights</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">architecture</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">pretrained_weights</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  309. <span class="sd">&quot;&quot;&quot;</span>
  310. <span class="sd"> Loads pretrained weights from the MODEL_URLS dictionary to model</span>
  311. <span class="sd"> @param architecture: name of the model&#39;s architecture</span>
  312. <span class="sd"> @param model: model to load pretrained weights for</span>
  313. <span class="sd"> @param pretrained_weights: name of the pretrained weights (e.g. imagenet)</span>
  314. <span class="sd"> @return: None</span>
  315. <span class="sd"> &quot;&quot;&quot;</span>
  316. <span class="n">model_url_key</span> <span class="o">=</span> <span class="n">architecture</span> <span class="o">+</span> <span class="s2">&quot;_&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">pretrained_weights</span><span class="p">)</span>
  317. <span class="k">if</span> <span class="n">model_url_key</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">MODEL_URLS</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
  318. <span class="k">raise</span> <span class="n">MissingPretrainedWeightsException</span><span class="p">(</span><span class="n">model_url_key</span><span class="p">)</span>
  319. <span class="n">url</span> <span class="o">=</span> <span class="n">MODEL_URLS</span><span class="p">[</span><span class="n">model_url_key</span><span class="p">]</span>
  320. <span class="n">unique_filename</span> <span class="o">=</span> <span class="n">url</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;https://deci-pretrained-models.s3.amazonaws.com/&quot;</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;/&quot;</span><span class="p">,</span> <span class="s2">&quot;_&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot; &quot;</span><span class="p">,</span> <span class="s2">&quot;_&quot;</span><span class="p">)</span>
  321. <span class="n">map_location</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
  322. <span class="n">pretrained_state_dict</span> <span class="o">=</span> <span class="n">load_state_dict_from_url</span><span class="p">(</span><span class="n">url</span><span class="o">=</span><span class="n">url</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="n">map_location</span><span class="p">,</span> <span class="n">file_name</span><span class="o">=</span><span class="n">unique_filename</span><span class="p">)</span>
  323. <span class="n">_load_weights</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">pretrained_state_dict</span><span class="p">)</span>
  324. <span class="k">def</span> <span class="nf">_load_weights</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">pretrained_state_dict</span><span class="p">):</span>
  325. <span class="k">if</span> <span class="s2">&quot;ema_net&quot;</span> <span class="ow">in</span> <span class="n">pretrained_state_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">():</span>
  326. <span class="n">pretrained_state_dict</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">pretrained_state_dict</span><span class="p">[</span><span class="s2">&quot;ema_net&quot;</span><span class="p">]</span>
  327. <span class="n">solver</span> <span class="o">=</span> <span class="n">_yolox_ckpt_solver</span> <span class="k">if</span> <span class="s2">&quot;yolox&quot;</span> <span class="ow">in</span> <span class="n">architecture</span> <span class="k">else</span> <span class="kc">None</span>
  328. <span class="n">adapted_pretrained_state_dict</span> <span class="o">=</span> <span class="n">adapt_state_dict_to_fit_model_layer_names</span><span class="p">(</span>
  329. <span class="n">model_state_dict</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="n">source_ckpt</span><span class="o">=</span><span class="n">pretrained_state_dict</span><span class="p">,</span> <span class="n">solver</span><span class="o">=</span><span class="n">solver</span>
  330. <span class="p">)</span>
  331. <span class="n">model</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">adapted_pretrained_state_dict</span><span class="p">[</span><span class="s2">&quot;net&quot;</span><span class="p">],</span> <span class="n">strict</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
  332. <span class="k">def</span> <span class="nf">load_pretrained_weights_local</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">architecture</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">pretrained_weights</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
  333. <span class="sd">&quot;&quot;&quot;</span>
  334. <span class="sd"> Loads pretrained weights from a local checkpoint file to model</span>
  335. <span class="sd"> @param architecture: name of the model&#39;s architecture</span>
  336. <span class="sd"> @param model: model to load pretrained weights for</span>
  337. <span class="sd"> @param pretrained_weights: path to the pretrained weights file</span>
  338. <span class="sd"> @return: None</span>
  339. <span class="sd"> &quot;&quot;&quot;</span>
  340. <span class="n">map_location</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;cpu&quot;</span><span class="p">)</span>
  341. <span class="n">pretrained_state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">pretrained_weights</span><span class="p">,</span> <span class="n">map_location</span><span class="o">=</span><span class="n">map_location</span><span class="p">)</span>
  342. <span class="n">_load_weights</span><span class="p">(</span><span class="n">architecture</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">pretrained_state_dict</span><span class="p">)</span>
  343. </pre></div>
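<p>The positional key remapping at the end of adapt_state_dict_to_fit_model_layer_names can be illustrated without PyTorch. The following is a minimal sketch, not the library's implementation: FakeTensor and remap_by_position are hypothetical names introduced here, and both inputs are assumed to be flat, ordered name-to-tensor mappings whose entries line up by position.</p>

```python
from collections import OrderedDict


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: it only carries a shape."""

    def __init__(self, *shape):
        self.shape = tuple(shape)


def remap_by_position(model_sd, ckpt_sd, solver=None):
    """Pair entries by order, validate shapes, and rename to the model's keys.

    Assumption: both dicts are flat and ordered the same way layer by layer.
    """
    remapped = OrderedDict()
    for (ckpt_key, ckpt_val), (model_key, model_val) in zip(ckpt_sd.items(), model_sd.items()):
        # An optional solver may reshape or replace a checkpoint value first.
        if solver is not None:
            ckpt_val = solver(ckpt_key, ckpt_val, model_key, model_val)
        if ckpt_val.shape != model_val.shape:
            raise ValueError(
                f"ckpt layer {ckpt_key} with shape {ckpt_val.shape} does not match "
                f"{model_key} with shape {model_val.shape} in the model"
            )
        # The checkpoint value is stored under the model's key name.
        remapped[model_key] = ckpt_val
    return {"net": remapped}


# Hypothetical layer names: the checkpoint keys carry a "module." prefix.
model_sd = OrderedDict([("conv.weight", FakeTensor(8, 3, 3, 3)), ("fc.weight", FakeTensor(10, 8))])
ckpt_sd = OrderedDict([("module.conv.weight", FakeTensor(8, 3, 3, 3)), ("module.fc.weight", FakeTensor(10, 8))])

adapted = remap_by_position(model_sd, ckpt_sd)
print(list(adapted["net"].keys()))  # ['conv.weight', 'fc.weight']
```

<p>Because the pairing is purely positional, a checkpoint whose layers are stored in a different order fails the shape check (or, if the shapes happen to coincide, is silently mismatched), which is why the shape comparison runs before every assignment.</p>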
  344. </div>
  345. </div>
  346. <footer>
  347. <hr/>
  348. <div role="contentinfo">
  349. <p>&#169; Copyright 2021, SuperGradients team.</p>
  350. </div>
  351. Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
  352. <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
  353. provided by <a href="https://readthedocs.org">Read the Docs</a>.
  354. </footer>
  355. </div>
  356. </div>
  357. </section>
  358. </div>
  359. <script>
  360. jQuery(function () {
  361. SphinxRtdTheme.Navigation.enable(true);
  362. });
  363. </script>
  364. </body>
  365. </html>