export.py

  1. """Export a YOLOv5 *.pt model to TorchScript, ONNX, CoreML formats
  2. Usage:
  3. $ python path/to/export.py --weights yolov5s.pt --img 640 --batch 1
  4. """
  5. import argparse
  6. import struct
  7. import sys
  8. import time
  9. from pathlib import Path
  10. import torch
  11. import torch.nn as nn
  12. from torch.utils.mobile_optimizer import optimize_for_mobile
  13. from utils.torch_utils import select_device
  14. FILE = Path(__file__).absolute()
  15. sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
  16. from models.common import Conv
  17. from models.yolo import Detect, attempt_download
  18. from models.experimental import attempt_load
  19. from utils.activations import Hardswish, SiLU
  20. from utils.general import colorstr, check_img_size, check_requirements, file_size, set_logging
  21. from utils.torch_utils import select_device


def export_torchscript(model, img, file, optimize):
    # TorchScript model export
    prefix = colorstr('TorchScript:')
    try:
        print(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript.pt')
        ts = torch.jit.trace(model, img, strict=False)
        (optimize_for_mobile(ts) if optimize else ts).save(f)
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return ts
    except Exception as e:
        print(f'{prefix} export failure: {e}')
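

# Illustration (not part of the original script): a minimal sketch of loading
# the saved TorchScript model back for inference, using the .torchscript.pt
# file written by export_torchscript() above. The default filename is an
# assumption.
def load_torchscript(f='yolov5s.torchscript.pt'):
    ts = torch.jit.load(str(f), map_location='cpu')  # deserialize the traced module
    ts.eval()  # inference mode
    return ts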


def export_onnx(model, img, file, opset, train, dynamic, simplify):
    # ONNX model export
    prefix = colorstr('ONNX:')
    try:
        check_requirements(('onnx', 'onnx-simplifier'))
        import onnx

        print(f'\n{prefix} starting export with onnx {onnx.__version__}...')
        f = file.with_suffix('.onnx')
        torch.onnx.export(model, img, f, verbose=False, opset_version=opset,
                          training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                          do_constant_folding=not train,
                          input_names=['images'],
                          output_names=['output'],
                          dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                                        'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                                        } if dynamic else None)

        # Checks
        model_onnx = onnx.load(f)  # load onnx model
        onnx.checker.check_model(model_onnx)  # check onnx model
        # print(onnx.helper.printable_graph(model_onnx.graph))  # print

        # Simplify
        if simplify:
            try:
                import onnxsim

                print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
                model_onnx, check = onnxsim.simplify(
                    model_onnx,
                    dynamic_input_shape=dynamic,
                    input_shapes={'images': list(img.shape)} if dynamic else None)
                assert check, 'assert check failed'
                onnx.save(model_onnx, f)
            except Exception as e:
                print(f'{prefix} simplifier failure: {e}')
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        print(f"{prefix} run --dynamic ONNX model inference with detect.py: 'python detect.py --weights {f}'")
    except Exception as e:
        print(f'{prefix} export failure: {e}')
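

# Illustration (not part of the original script): a minimal sketch of running
# the exported ONNX model with onnxruntime (an assumption; any ONNX runtime
# works). The tensor names 'images' and 'output' match export_onnx() above.
def check_onnx_inference(f='yolov5s.onnx'):
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(str(f), providers=['CPUExecutionProvider'])
    img = np.zeros((1, 3, 640, 640), dtype=np.float32)  # dummy 1x3x640x640 input
    pred = session.run(['output'], {'images': img})[0]
    print(pred.shape)  # e.g. (1, 25200, 85) for an 80-class model at 640x640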


def export_coreml(model, img, file):
    # CoreML model export
    prefix = colorstr('CoreML:')
    try:
        import coremltools as ct

        print(f'\n{prefix} starting export with coremltools {ct.__version__}...')
        f = file.with_suffix('.mlmodel')
        model.train()  # CoreML exports should be placed in model.train() mode
        ts = torch.jit.trace(model, img, strict=False)  # TorchScript model
        model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
        model.save(f)
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        print(f'\n{prefix} export failure: {e}')


def export_wts(weights, file):
    # Export for https://github.com/wang-xinyu/tensorrtx/tree/master/yolov5
    # credits: Wang-Xinyu
    attempt_download(weights)
    weights_file = Path(weights)
    file_wts = Path(weights_file.parents[0], str(weights_file.stem).strip().replace("'", '') + '.wts')
    model = torch.load(weights, map_location='cpu')['model'].float().eval()  # load to FP32
    with open(file_wts, 'w') as f:
        f.write('{}\n'.format(len(model.state_dict().keys())))
        for k, v in model.state_dict().items():
            vr = v.reshape(-1).cpu().numpy()
            f.write('{} {} '.format(k, len(vr)))
            for vv in vr:
                f.write(' ')
                f.write(struct.pack('>f', float(vv)).hex())
            f.write('\n')
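

# Illustration (not part of the original script): a minimal sketch of parsing
# the .wts file written by export_wts() above. The first line is the tensor
# count; each following line is 'name length' and then the flattened values as
# big-endian float32, hex-encoded 8 characters per value.
def load_wts(file_wts):
    weights = {}
    with open(file_wts) as f:
        count = int(f.readline())  # first line: number of tensors
        for _ in range(count):
            name, n, *hex_vals = f.readline().split()
            weights[name] = [struct.unpack('>f', bytes.fromhex(h))[0] for h in hex_vals[:int(n)]]
    return weights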


def run(weights='./yolov5s.pt',  # weights path
        img_size=(640, 640),  # image (height, width)
        batch_size=1,  # batch size
        device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        include=('torchscript', 'onnx', 'coreml', 'wts'),  # include formats
        half=False,  # FP16 half-precision export
        inplace=False,  # set YOLOv5 Detect() inplace=True
        train=False,  # model.train() mode
        optimize=False,  # TorchScript: optimize for mobile
        dynamic=False,  # ONNX: dynamic axes
        simplify=False,  # ONNX: simplify model
        opset=12,  # ONNX: opset version
        ):
    t = time.time()
    include = [x.lower() for x in include]
    img_size *= 2 if len(img_size) == 1 else 1  # expand

    file = Path(weights)

    # Load PyTorch model
    device = select_device(device)
    assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
    model = attempt_load(weights, map_location=device)  # load FP32 model
    names = model.names

    # Input
    gs = int(max(model.stride))  # grid size (max stride)
    img_size = [check_img_size(x, gs) for x in img_size]  # verify img_size are gs-multiples
    img = torch.zeros(batch_size, 3, *img_size).to(device)  # image size(1,3,320,192) iDetection

    # Update model
    if half:
        img, model = img.half(), model.half()  # to FP16
    model.train() if train else model.eval()  # training mode = no Detect() layer grid construction
    for k, m in model.named_modules():
        if isinstance(m, Conv):  # assign export-friendly activations
            if isinstance(m.act, nn.Hardswish):
                m.act = Hardswish()
            elif isinstance(m.act, nn.SiLU):
                m.act = SiLU()
        elif isinstance(m, Detect):
            m.inplace = inplace
            m.onnx_dynamic = dynamic
            # m.forward = m.forward_export  # assign forward (optional)

    for _ in range(2):
        y = model(img)  # dry runs
    print(f"\n{colorstr('PyTorch:')} starting from {weights} ({file_size(weights):.1f} MB)")

    # Exports
    if 'torchscript' in include:
        export_torchscript(model, img, file, optimize)
    if 'onnx' in include:
        export_onnx(model, img, file, opset, train, dynamic, simplify)
    if 'coreml' in include:
        export_coreml(model, img, file)
    if 'wts' in include:
        export_wts(weights, file)

    # Finish
    print(f'\nExport complete ({time.time() - t:.2f}s)'
          f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
          f'\nVisualize with https://netron.app')


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image (height, width)')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--include', nargs='+', default=['torchscript', 'onnx', 'coreml', 'wts'], help='include formats')
    parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
    parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
    parser.add_argument('--train', action='store_true', help='model.train() mode')
    parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
    parser.add_argument('--dynamic', action='store_true', help='ONNX: dynamic axes')
    parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
    parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
    opt = parser.parse_args()
    return opt


def main(opt):
    set_logging()
    print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
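

# Example invocations (a sketch; flags match parse_opt() above, weight
# filenames are assumptions):
#   $ python export.py --weights yolov5s.pt --img-size 640 --batch-size 1
#   $ python export.py --weights yolov5s.pt --include onnx --dynamic --simplify
#   $ python export.py --weights yolov5s.pt --include torchscript --optimize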