1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
|
- <!DOCTYPE html>
- <html class="writer-html5" lang="en" >
- <head>
- <meta charset="utf-8" />
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
- <title>super_gradients.training.utils.weight_averaging_utils — SuperGradients 1.0 documentation</title>
- <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
- <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
- <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
- <!--[if lt IE 9]>
- <script src="../../../../_static/js/html5shiv.min.js"></script>
- <![endif]-->
-
- <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
- <script src="../../../../_static/jquery.js"></script>
- <script src="../../../../_static/underscore.js"></script>
- <script src="../../../../_static/doctools.js"></script>
- <script src="../../../../_static/js/theme.js"></script>
- <link rel="index" title="Index" href="../../../../genindex.html" />
- <link rel="search" title="Search" href="../../../../search.html" />
- </head>
- <body class="wy-body-for-nav">
- <div class="wy-grid-for-nav">
- <nav data-toggle="wy-nav-shift" class="wy-nav-side">
- <div class="wy-side-scroll">
- <div class="wy-side-nav-search" >
- <a href="../../../../index.html" class="icon icon-home"> SuperGradients
- </a>
- <div role="search">
- <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
- <input type="text" name="q" placeholder="Search docs" />
- <input type="hidden" name="check_keywords" value="yes" />
- <input type="hidden" name="area" value="default" />
- </form>
- </div>
- </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
- <p class="caption"><span class="caption-text">Welcome To SuperGradients</span></p>
- <ul>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">Fill our 4-question quick survey! We will raffle free SuperGradients swag between those who will participate -> Fill Survey</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html#supergradients">SuperGradients</a></li>
- </ul>
- <p class="caption"><span class="caption-text">Technical Documentation</span></p>
- <ul>
- <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common package</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training package</a></li>
- </ul>
- <p class="caption"><span class="caption-text">User Guide</span></p>
- <ul>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html">What is SuperGradients?</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#introducing-the-supergradients-library">Introducing the SuperGradients library</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#installation">Installation</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#integrating-your-training-code-complete-walkthrough">Integrating Your Training Code - Complete Walkthrough</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#training-parameters">Training Parameters</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#logs-and-checkpoints">Logs and Checkpoints</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#dataset-parameters">Dataset Parameters</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#network-architectures">Network Architectures</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#pretrained-models">Pretrained Models</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#how-to-reproduce-our-training-recipes">How To Reproduce Our Training Recipes</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#professional-tools-integration">Professional Tools Integration</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../user_guide.html#supergradients-faq">SuperGradients FAQ</a></li>
- </ul>
- </div>
- </div>
- </nav>
- <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
- <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
- <a href="../../../../index.html">SuperGradients</a>
- </nav>
- <div class="wy-nav-content">
- <div class="rst-content">
- <div role="navigation" aria-label="Page navigation">
- <ul class="wy-breadcrumbs">
- <li><a href="../../../../index.html" class="icon icon-home"></a> »</li>
- <li><a href="../../../index.html">Module code</a> »</li>
- <li>super_gradients.training.utils.weight_averaging_utils</li>
- <li class="wy-breadcrumbs-aside">
- </li>
- </ul>
- <hr/>
- </div>
- <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
- <div itemprop="articleBody">
-
- <h1>Source code for super_gradients.training.utils.weight_averaging_utils</h1><div class="highlight"><pre>
- <span></span><span class="kn">import</span> <span class="nn">os</span>
- <span class="kn">import</span> <span class="nn">torch</span>
- <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
- <span class="kn">import</span> <span class="nn">pkg_resources</span>
- <span class="kn">from</span> <span class="nn">super_gradients.training</span> <span class="kn">import</span> <span class="n">utils</span> <span class="k">as</span> <span class="n">core_utils</span>
- <span class="kn">from</span> <span class="nn">super_gradients.training.utils.utils</span> <span class="kn">import</span> <span class="n">move_state_dict_to_device</span>
- <div class="viewcode-block" id="ModelWeightAveraging"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.weight_averaging_utils.ModelWeightAveraging">[docs]</a><span class="k">class</span> <span class="nc">ModelWeightAveraging</span><span class="p">:</span>
- <span class="sd">"""</span>
- <span class="sd"> Utils class for managing the averaging of the best several snapshots into a single model.</span>
- <span class="sd"> A snapshot dictionary file and the average model will be saved / updated at every epoch and evaluated only when</span>
- <span class="sd"> training is completed. The snapshot file will only be deleted upon completing the training.</span>
- <span class="sd"> The snapshot dict will be managed on cpu.</span>
- <span class="sd"> """</span>
- <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ckpt_dir</span><span class="p">,</span>
- <span class="n">greater_is_better</span><span class="p">,</span>
- <span class="n">source_ckpt_folder_name</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">metric_to_watch</span><span class="o">=</span><span class="s1">'acc'</span><span class="p">,</span>
- <span class="n">metric_idx</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">load_checkpoint</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
- <span class="n">number_of_models_to_average</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
- <span class="n">model_checkpoints_location</span><span class="o">=</span><span class="s1">'local'</span>
- <span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> Init the ModelWeightAveraging</span>
- <span class="sd"> :param checkpoint_dir: the directory where the checkpoints are saved</span>
- <span class="sd"> :param metric_to_watch: monitoring loss or acc, will be identical to that which determines best_model</span>
- <span class="sd"> :param metric_idx:</span>
- <span class="sd"> :param load_checkpoint: whether to load pre-existing snapshot dict.</span>
- <span class="sd"> :param number_of_models_to_average: number of models to average</span>
- <span class="sd"> """</span>
- <span class="k">if</span> <span class="n">source_ckpt_folder_name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
- <span class="n">source_ckpt_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">source_ckpt_folder_name</span><span class="p">,</span> <span class="s1">'averaging_snapshots.pkl'</span><span class="p">)</span>
- <span class="n">source_ckpt_file</span> <span class="o">=</span> <span class="n">pkg_resources</span><span class="o">.</span><span class="n">resource_filename</span><span class="p">(</span><span class="s1">'checkpoints'</span><span class="p">,</span> <span class="n">source_ckpt_file</span><span class="p">)</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">ckpt_dir</span><span class="p">,</span> <span class="s1">'averaging_snapshots.pkl'</span><span class="p">)</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span> <span class="o">=</span> <span class="n">number_of_models_to_average</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">metric_to_watch</span> <span class="o">=</span> <span class="n">metric_to_watch</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">metric_idx</span> <span class="o">=</span> <span class="n">metric_idx</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="o">=</span> <span class="n">greater_is_better</span>
- <span class="c1"># if continuing training, copy previous snapshot dict if exist</span>
- <span class="k">if</span> <span class="n">load_checkpoint</span> <span class="ow">and</span> <span class="n">source_ckpt_folder_name</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isfile</span><span class="p">(</span><span class="n">source_ckpt_file</span><span class="p">):</span>
- <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="n">core_utils</span><span class="o">.</span><span class="n">load_checkpoint</span><span class="p">(</span><span class="n">ckpt_destination_dir</span><span class="o">=</span><span class="n">ckpt_dir</span><span class="p">,</span>
- <span class="n">source_ckpt_folder_name</span><span class="o">=</span><span class="n">source_ckpt_folder_name</span><span class="p">,</span>
- <span class="n">ckpt_filename</span><span class="o">=</span><span class="s2">"averaging_snapshots.pkl"</span><span class="p">,</span>
- <span class="n">load_weights_only</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
- <span class="n">model_checkpoints_location</span><span class="o">=</span><span class="n">model_checkpoints_location</span><span class="p">,</span>
- <span class="n">overwrite_local_ckpt</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
- <span class="k">else</span><span class="p">:</span>
- <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'snapshot'</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">):</span> <span class="kc">None</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span><span class="p">)}</span>
- <span class="c1"># if metric to watch is acc, hold a zero array, if loss hold inf array</span>
- <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span><span class="p">:</span>
- <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">'snapshots_metric'</span><span class="p">]</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span><span class="p">)</span>
- <span class="k">else</span><span class="p">:</span>
- <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">'snapshots_metric'</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span><span class="p">)</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">averaging_snapshots_dict</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span><span class="p">)</span>
- <div class="viewcode-block" id="ModelWeightAveraging.update_snapshots_dict"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.weight_averaging_utils.ModelWeightAveraging.update_snapshots_dict">[docs]</a> <span class="k">def</span> <span class="nf">update_snapshots_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> Update the snapshot dict and returns the updated average model for saving</span>
- <span class="sd"> :param model: the latest model</span>
- <span class="sd"> :param validation_results_tuple: performance of the latest model</span>
- <span class="sd"> """</span>
- <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_averaging_snapshots_dict</span><span class="p">()</span>
- <span class="c1"># IF CURRENT MODEL IS BETTER, TAKING HIS PLACE IN ACC LIST AND OVERWRITE THE NEW AVERAGE</span>
- <span class="n">require_update</span><span class="p">,</span> <span class="n">update_ind</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_is_better</span><span class="p">(</span><span class="n">averaging_snapshots_dict</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">)</span>
- <span class="k">if</span> <span class="n">require_update</span><span class="p">:</span>
- <span class="c1"># moving state dict to cpu</span>
- <span class="n">new_sd</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()</span>
- <span class="n">new_sd</span> <span class="o">=</span> <span class="n">move_state_dict_to_device</span><span class="p">(</span><span class="n">new_sd</span><span class="p">,</span> <span class="s1">'cpu'</span><span class="p">)</span>
- <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">'snapshot'</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">update_ind</span><span class="p">)]</span> <span class="o">=</span> <span class="n">new_sd</span>
- <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">'snapshots_metric'</span><span class="p">][</span><span class="n">update_ind</span><span class="p">]</span> <span class="o">=</span> <span class="n">validation_results_tuple</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx</span><span class="p">]</span>
- <span class="k">return</span> <span class="n">averaging_snapshots_dict</span></div>
- <div class="viewcode-block" id="ModelWeightAveraging.get_average_model"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.weight_averaging_utils.ModelWeightAveraging.get_average_model">[docs]</a> <span class="k">def</span> <span class="nf">get_average_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> Returns the averaged model</span>
- <span class="sd"> :param model: will be used to determine arch</span>
- <span class="sd"> :param validation_results_tuple: if provided, will update the average model before returning</span>
- <span class="sd"> :param target_device: if provided, return sd on target device</span>
- <span class="sd"> """</span>
- <span class="c1"># If validation tuple is provided, update the average model</span>
- <span class="k">if</span> <span class="n">validation_results_tuple</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
- <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">update_snapshots_dict</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">)</span>
- <span class="k">else</span><span class="p">:</span>
- <span class="n">averaging_snapshots_dict</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_get_averaging_snapshots_dict</span><span class="p">()</span>
- <span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">averaging_snapshots_dict</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span><span class="p">)</span>
- <span class="n">average_model_sd</span> <span class="o">=</span> <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">'snapshot0'</span><span class="p">]</span>
- <span class="k">for</span> <span class="n">n_model</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">number_of_models_to_average</span><span class="p">):</span>
- <span class="k">if</span> <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">'snapshot'</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">n_model</span><span class="p">)]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
- <span class="n">net_sd</span> <span class="o">=</span> <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">'snapshot'</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">n_model</span><span class="p">)]</span>
- <span class="c1"># USING MOVING AVERAGE</span>
- <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">average_model_sd</span><span class="p">:</span>
- <span class="n">average_model_sd</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">true_divide</span><span class="p">(</span>
- <span class="n">average_model_sd</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">*</span> <span class="n">n_model</span> <span class="o">+</span> <span class="n">net_sd</span><span class="p">[</span><span class="n">key</span><span class="p">],</span>
- <span class="p">(</span><span class="n">n_model</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>
- <span class="k">return</span> <span class="n">average_model_sd</span></div>
- <div class="viewcode-block" id="ModelWeightAveraging.cleanup"><a class="viewcode-back" href="../../../../super_gradients.training.utils.html#super_gradients.training.utils.weight_averaging_utils.ModelWeightAveraging.cleanup">[docs]</a> <span class="k">def</span> <span class="nf">cleanup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> Delete snapshot file when reaching the last epoch</span>
- <span class="sd"> """</span>
- <span class="n">os</span><span class="o">.</span><span class="n">remove</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span><span class="p">)</span></div>
- <span class="k">def</span> <span class="nf">_is_better</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">averaging_snapshots_dict</span><span class="p">,</span> <span class="n">validation_results_tuple</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> Determines if the new model is better according to the specified metrics</span>
- <span class="sd"> :param averaging_snapshots_dict: snapshot dict</span>
- <span class="sd"> :param validation_results_tuple: latest model performance</span>
- <span class="sd"> """</span>
- <span class="n">snapshot_metric_array</span> <span class="o">=</span> <span class="n">averaging_snapshots_dict</span><span class="p">[</span><span class="s1">'snapshots_metric'</span><span class="p">]</span>
- <span class="n">val</span> <span class="o">=</span> <span class="n">validation_results_tuple</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">metric_idx</span><span class="p">]</span>
- <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span><span class="p">:</span>
- <span class="n">update_ind</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmin</span><span class="p">(</span><span class="n">snapshot_metric_array</span><span class="p">)</span>
- <span class="k">else</span><span class="p">:</span>
- <span class="n">update_ind</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">snapshot_metric_array</span><span class="p">)</span>
- <span class="k">if</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="ow">and</span> <span class="n">val</span> <span class="o">></span> <span class="n">snapshot_metric_array</span><span class="p">[</span><span class="n">update_ind</span><span class="p">])</span> <span class="ow">or</span> <span class="p">(</span>
- <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">greater_is_better</span> <span class="ow">and</span> <span class="n">val</span> <span class="o"><</span> <span class="n">snapshot_metric_array</span><span class="p">[</span><span class="n">update_ind</span><span class="p">]):</span>
- <span class="k">return</span> <span class="kc">True</span><span class="p">,</span> <span class="n">update_ind</span>
- <span class="k">return</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">None</span>
- <span class="k">def</span> <span class="nf">_get_averaging_snapshots_dict</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
- <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">averaging_snapshots_file</span><span class="p">)</span></div>
- </pre></div>
- </div>
- </div>
- <footer>
- <hr/>
- <div role="contentinfo">
- <p>© Copyright 2021, SuperGradients team.</p>
- </div>
- Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
- <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
- provided by <a href="https://readthedocs.org">Read the Docs</a>.
-
- </footer>
- </div>
- </div>
- </section>
- </div>
- <script>
- jQuery(function () {
- SphinxRtdTheme.Navigation.enable(true);
- });
- </script>
- </body>
- </html>
|