1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
|
- <!DOCTYPE html>
- <html class="writer-html5" lang="en" >
- <head>
- <meta charset="utf-8" />
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
- <title>super_gradients.training.losses.ssd_loss — SuperGradients 1.0 documentation</title>
- <link rel="stylesheet" href="../../../../_static/pygments.css" type="text/css" />
- <link rel="stylesheet" href="../../../../_static/css/theme.css" type="text/css" />
- <link rel="stylesheet" href="../../../../_static/graphviz.css" type="text/css" />
- <!--[if lt IE 9]>
- <script src="../../../../_static/js/html5shiv.min.js"></script>
- <![endif]-->
-
- <script data-url_root="../../../../" id="documentation_options" src="../../../../_static/documentation_options.js"></script>
- <script src="../../../../_static/jquery.js"></script>
- <script src="../../../../_static/underscore.js"></script>
- <script src="../../../../_static/doctools.js"></script>
- <script src="../../../../_static/js/theme.js"></script>
- <link rel="index" title="Index" href="../../../../genindex.html" />
- <link rel="search" title="Search" href="../../../../search.html" />
- </head>
- <body class="wy-body-for-nav">
- <div class="wy-grid-for-nav">
- <nav data-toggle="wy-nav-shift" class="wy-nav-side">
- <div class="wy-side-scroll">
- <div class="wy-side-nav-search" >
- <a href="../../../../index.html" class="icon icon-home"> SuperGradients
- </a>
- <div role="search">
- <form id="rtd-search-form" class="wy-form" action="../../../../search.html" method="get">
- <input type="text" name="q" placeholder="Search docs" />
- <input type="hidden" name="check_keywords" value="yes" />
- <input type="hidden" name="area" value="default" />
- </form>
- </div>
- </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
- <p class="caption"><span class="caption-text">SuperGradients</span></p>
- <ul>
- <li class="toctree-l1"><a class="reference internal" href="../../../../welcome.html">SuperGradients</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.common.html">Common</a></li>
- <li class="toctree-l1"><a class="reference internal" href="../../../../super_gradients.training.html">Training</a></li>
- </ul>
- </div>
- </div>
- </nav>
- <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
- <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
- <a href="../../../../index.html">SuperGradients</a>
- </nav>
- <div class="wy-nav-content">
- <div class="rst-content">
- <div role="navigation" aria-label="Page navigation">
- <ul class="wy-breadcrumbs">
- <li><a href="../../../../index.html" class="icon icon-home"></a> »</li>
- <li><a href="../../../index.html">Module code</a> »</li>
- <li>super_gradients.training.losses.ssd_loss</li>
- <li class="wy-breadcrumbs-aside">
- </li>
- </ul>
- <hr/>
- </div>
- <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
- <div itemprop="articleBody">
-
- <h1>Source code for super_gradients.training.losses.ssd_loss</h1><div class="highlight"><pre>
- <span></span><span class="kn">import</span> <span class="nn">torch</span>
- <span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
- <span class="kn">from</span> <span class="nn">torch.nn.modules.loss</span> <span class="kn">import</span> <span class="n">_Loss</span>
- <span class="kn">from</span> <span class="nn">super_gradients.training.utils.detection_utils</span> <span class="kn">import</span> <span class="n">calculate_bbox_iou_matrix</span>
- <span class="kn">from</span> <span class="nn">super_gradients.training.utils.ssd_utils</span> <span class="kn">import</span> <span class="n">DefaultBoxes</span>
- <div class="viewcode-block" id="SSDLoss"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ssd_loss.SSDLoss">[docs]</a><span class="k">class</span> <span class="nc">SSDLoss</span><span class="p">(</span><span class="n">_Loss</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> Implements the loss as the sum of the followings:</span>
- <span class="sd"> 1. Confidence Loss: All labels, with hard negative mining</span>
- <span class="sd"> 2. Localization Loss: Only on positive labels</span>
- <span class="sd"> """</span>
- <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dboxes</span><span class="p">:</span> <span class="n">DefaultBoxes</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">):</span>
- <span class="nb">super</span><span class="p">(</span><span class="n">SSDLoss</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">scale_xy</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">dboxes</span><span class="o">.</span><span class="n">scale_xy</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">scale_wh</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">dboxes</span><span class="o">.</span><span class="n">scale_wh</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">sl1_loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">SmoothL1Loss</span><span class="p">(</span><span class="n">reduce</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">dboxes</span><span class="p">(</span><span class="n">order</span><span class="o">=</span><span class="s2">"xywh"</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
- <span class="bp">self</span><span class="o">.</span><span class="n">con_loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">reduce</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
- <span class="k">def</span> <span class="nf">_norm_relative_bbox</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loc</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> convert bbox locations into relative locations (relative to the dboxes) and normalized by w,h</span>
- <span class="sd"> :param loc a tensor of shape [batch, 4, num_boxes]</span>
- <span class="sd"> """</span>
- <span class="n">gxy</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_xy</span> <span class="o">*</span> <span class="p">(</span><span class="n">loc</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">,</span> <span class="p">:]</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">,</span> <span class="p">:])</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">]</span>
- <span class="n">gwh</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale_wh</span> <span class="o">*</span> <span class="p">(</span><span class="n">loc</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">:]</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:,</span> <span class="p">:])</span><span class="o">.</span><span class="n">log</span><span class="p">()</span>
- <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">gxy</span><span class="p">,</span> <span class="n">gwh</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
- <div class="viewcode-block" id="SSDLoss.match_dboxes"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ssd_loss.SSDLoss.match_dboxes">[docs]</a> <span class="k">def</span> <span class="nf">match_dboxes</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> convert ground truth boxes into a tensor with the same size as dboxes. each gt bbox is matched to every</span>
- <span class="sd"> destination box which overlaps it over 0.5 (IoU). so some gt bboxes can be duplicated to a few destination boxes</span>
- <span class="sd"> :param targets: a tensor containing the boxes for a single image. shape [num_boxes, 5] (x,y,w,h,label)</span>
- <span class="sd"> :return: two tensors</span>
- <span class="sd"> boxes - shape of dboxes [4, num_dboxes] (x,y,w,h)</span>
- <span class="sd"> labels - sahpe [num_dboxes]</span>
- <span class="sd"> """</span>
- <span class="n">target_locations</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span>
- <span class="n">target_labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]))</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
- <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
- <span class="n">boxes</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">:]</span>
- <span class="n">ious</span> <span class="o">=</span> <span class="n">calculate_bbox_iou_matrix</span><span class="p">(</span><span class="n">boxes</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span><span class="o">.</span><span class="n">T</span><span class="p">,</span> <span class="n">x1y1x2y2</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
- <span class="n">values</span><span class="p">,</span> <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">ious</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
- <span class="n">mask</span> <span class="o">=</span> <span class="n">values</span> <span class="o">></span> <span class="mf">0.5</span>
- <span class="n">target_locations</span><span class="p">[:,</span> <span class="n">mask</span><span class="p">]</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">indices</span><span class="p">[</span><span class="n">mask</span><span class="p">],</span> <span class="mi">2</span><span class="p">:]</span><span class="o">.</span><span class="n">T</span>
- <span class="n">target_labels</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span> <span class="o">=</span> <span class="n">targets</span><span class="p">[</span><span class="n">indices</span><span class="p">[</span><span class="n">mask</span><span class="p">],</span> <span class="mi">1</span><span class="p">]</span>
- <span class="k">return</span> <span class="n">target_locations</span><span class="p">,</span> <span class="n">target_labels</span></div>
- <div class="viewcode-block" id="SSDLoss.forward"><a class="viewcode-back" href="../../../../super_gradients.training.losses.html#super_gradients.training.losses.ssd_loss.SSDLoss.forward">[docs]</a> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">predictions</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
- <span class="sd">"""</span>
- <span class="sd"> Compute the loss</span>
- <span class="sd"> :param predictions - predictions tensor coming from the network. shape [N, num_classes+4, num_dboxes]</span>
- <span class="sd"> were the first four items are (x,y,w,h) and the rest are class confidence</span>
- <span class="sd"> :param targets - targets for the batch. [num targets, 6] (index in batch, label, x,y,w,h)</span>
- <span class="sd"> """</span>
- <span class="n">batch_target_locations</span> <span class="o">=</span> <span class="p">[]</span>
- <span class="n">batch_target_labels</span> <span class="o">=</span> <span class="p">[]</span>
- <span class="p">(</span><span class="n">ploc</span><span class="p">,</span> <span class="n">plabel</span><span class="p">)</span> <span class="o">=</span> <span class="n">predictions</span>
- <span class="n">targets</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dboxes</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
- <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ploc</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span>
- <span class="n">target_locations</span><span class="p">,</span> <span class="n">target_labels</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">match_dboxes</span><span class="p">(</span><span class="n">targets</span><span class="p">[</span><span class="n">targets</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">i</span><span class="p">])</span>
- <span class="n">batch_target_locations</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">target_locations</span><span class="p">)</span>
- <span class="n">batch_target_labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">target_labels</span><span class="p">)</span>
- <span class="n">batch_target_locations</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_target_locations</span><span class="p">)</span>
- <span class="n">batch_target_labels</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_target_labels</span><span class="p">)</span><span class="o">.</span><span class="n">type</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span>
- <span class="n">mask</span> <span class="o">=</span> <span class="n">batch_target_labels</span> <span class="o">></span> <span class="mi">0</span>
- <span class="n">pos_num</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="n">vec_gd</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_norm_relative_bbox</span><span class="p">(</span><span class="n">batch_target_locations</span><span class="p">)</span>
- <span class="c1"># SUM ON FOUR COORDINATES, AND MASK</span>
- <span class="n">sl1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sl1_loss</span><span class="p">(</span><span class="n">ploc</span><span class="p">,</span> <span class="n">vec_gd</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="n">sl1</span> <span class="o">=</span> <span class="p">(</span><span class="n">mask</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">*</span> <span class="n">sl1</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="c1"># HARD NEGATIVE MINING</span>
- <span class="n">con</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">con_loss</span><span class="p">(</span><span class="n">plabel</span><span class="p">,</span> <span class="n">batch_target_labels</span><span class="p">)</span>
- <span class="c1"># POSITIVE MASK WILL NEVER SELECTED</span>
- <span class="n">con_neg</span> <span class="o">=</span> <span class="n">con</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
- <span class="n">con_neg</span><span class="p">[</span><span class="n">mask</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span>
- <span class="n">_</span><span class="p">,</span> <span class="n">con_idx</span> <span class="o">=</span> <span class="n">con_neg</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">descending</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
- <span class="n">_</span><span class="p">,</span> <span class="n">con_rank</span> <span class="o">=</span> <span class="n">con_idx</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="c1"># NUMBER OF NEGATIVE THREE TIMES POSITIVE</span>
- <span class="n">neg_num</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="mi">3</span> <span class="o">*</span> <span class="n">pos_num</span><span class="p">,</span> <span class="nb">max</span><span class="o">=</span><span class="n">mask</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
- <span class="n">neg_mask</span> <span class="o">=</span> <span class="n">con_rank</span> <span class="o"><</span> <span class="n">neg_num</span>
- <span class="n">closs</span> <span class="o">=</span> <span class="p">(</span><span class="n">con</span> <span class="o">*</span> <span class="p">(</span><span class="n">mask</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="o">+</span> <span class="n">neg_mask</span><span class="o">.</span><span class="n">float</span><span class="p">()))</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
- <span class="c1"># AVOID NO OBJECT DETECTED</span>
- <span class="n">total_loss</span> <span class="o">=</span> <span class="p">(</span><span class="mi">2</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">)</span> <span class="o">*</span> <span class="n">sl1</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">*</span> <span class="n">closs</span>
- <span class="n">num_mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">pos_num</span> <span class="o">></span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
- <span class="n">pos_num</span> <span class="o">=</span> <span class="n">pos_num</span><span class="o">.</span><span class="n">float</span><span class="p">()</span><span class="o">.</span><span class="n">clamp</span><span class="p">(</span><span class="nb">min</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)</span>
- <span class="n">ret</span> <span class="o">=</span> <span class="p">(</span><span class="n">total_loss</span> <span class="o">*</span> <span class="n">num_mask</span> <span class="o">/</span> <span class="n">pos_num</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
- <span class="k">return</span> <span class="n">ret</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">sl1</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">closs</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">ret</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)))</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span></div></div>
- </pre></div>
- </div>
- </div>
- <footer>
- <hr/>
- <div role="contentinfo">
- <p>© Copyright 2021, SuperGradients team.</p>
- </div>
- Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
- <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
- provided by <a href="https://readthedocs.org">Read the Docs</a>.
-
- </footer>
- </div>
- </div>
- </section>
- </div>
- <script>
- jQuery(function () {
- SphinxRtdTheme.Navigation.enable(true);
- });
- </script>
- </body>
- </html>
|