Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

export.py 7.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Export a PyTorch model to TorchScript, ONNX, CoreML formats
  4. Usage:
  5. $ python path/to/export.py --weights yolov5s.pt --img 640 --batch 1
  6. """
  7. import argparse
  8. import sys
  9. import time
  10. from pathlib import Path
  11. import torch
  12. import torch.nn as nn
  13. from torch.utils.mobile_optimizer import optimize_for_mobile
  14. FILE = Path(__file__).absolute()
  15. sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
  16. from models.common import Conv
  17. from models.yolo import Detect
  18. from models.experimental import attempt_load
  19. from utils.activations import Hardswish, SiLU
  20. from utils.general import colorstr, check_img_size, check_requirements, file_size, set_logging
  21. from utils.torch_utils import select_device
  22. def export_torchscript(model, img, file, optimize):
  23. # TorchScript model export
  24. prefix = colorstr('TorchScript:')
  25. try:
  26. print(f'\n{prefix} starting export with torch {torch.__version__}...')
  27. f = file.with_suffix('.torchscript.pt')
  28. ts = torch.jit.trace(model, img, strict=False)
  29. (optimize_for_mobile(ts) if optimize else ts).save(f)
  30. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  31. return ts
  32. except Exception as e:
  33. print(f'{prefix} export failure: {e}')
  34. def export_onnx(model, img, file, opset, train, dynamic, simplify):
  35. # ONNX model export
  36. prefix = colorstr('ONNX:')
  37. try:
  38. check_requirements(('onnx',))
  39. import onnx
  40. print(f'\n{prefix} starting export with onnx {onnx.__version__}...')
  41. f = file.with_suffix('.onnx')
  42. torch.onnx.export(model, img, f, verbose=False, opset_version=opset,
  43. training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
  44. do_constant_folding=not train,
  45. input_names=['images'],
  46. output_names=['output'],
  47. dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
  48. 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
  49. } if dynamic else None)
  50. # Checks
  51. model_onnx = onnx.load(f) # load onnx model
  52. onnx.checker.check_model(model_onnx) # check onnx model
  53. # print(onnx.helper.printable_graph(model_onnx.graph)) # print
  54. # Simplify
  55. if simplify:
  56. try:
  57. check_requirements(('onnx-simplifier',))
  58. import onnxsim
  59. print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
  60. model_onnx, check = onnxsim.simplify(
  61. model_onnx,
  62. dynamic_input_shape=dynamic,
  63. input_shapes={'images': list(img.shape)} if dynamic else None)
  64. assert check, 'assert check failed'
  65. onnx.save(model_onnx, f)
  66. except Exception as e:
  67. print(f'{prefix} simplifier failure: {e}')
  68. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  69. print(f"{prefix} run --dynamic ONNX model inference with: 'python detect.py --weights {f}'")
  70. except Exception as e:
  71. print(f'{prefix} export failure: {e}')
  72. def export_coreml(model, img, file):
  73. # CoreML model export
  74. prefix = colorstr('CoreML:')
  75. try:
  76. check_requirements(('coremltools',))
  77. import coremltools as ct
  78. print(f'\n{prefix} starting export with coremltools {ct.__version__}...')
  79. f = file.with_suffix('.mlmodel')
  80. model.train() # CoreML exports should be placed in model.train() mode
  81. ts = torch.jit.trace(model, img, strict=False) # TorchScript model
  82. model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
  83. model.save(f)
  84. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  85. except Exception as e:
  86. print(f'\n{prefix} export failure: {e}')
  87. def run(weights='./yolov5s.pt', # weights path
  88. img_size=(640, 640), # image (height, width)
  89. batch_size=1, # batch size
  90. device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
  91. include=('torchscript', 'onnx', 'coreml'), # include formats
  92. half=False, # FP16 half-precision export
  93. inplace=False, # set YOLOv5 Detect() inplace=True
  94. train=False, # model.train() mode
  95. optimize=False, # TorchScript: optimize for mobile
  96. dynamic=False, # ONNX: dynamic axes
  97. simplify=False, # ONNX: simplify model
  98. opset=12, # ONNX: opset version
  99. ):
  100. t = time.time()
  101. include = [x.lower() for x in include]
  102. img_size *= 2 if len(img_size) == 1 else 1 # expand
  103. file = Path(weights)
  104. # Load PyTorch model
  105. device = select_device(device)
  106. assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
  107. model = attempt_load(weights, map_location=device) # load FP32 model
  108. names = model.names
  109. # Input
  110. gs = int(max(model.stride)) # grid size (max stride)
  111. img_size = [check_img_size(x, gs) for x in img_size] # verify img_size are gs-multiples
  112. img = torch.zeros(batch_size, 3, *img_size).to(device) # image size(1,3,320,192) iDetection
  113. # Update model
  114. if half:
  115. img, model = img.half(), model.half() # to FP16
  116. model.train() if train else model.eval() # training mode = no Detect() layer grid construction
  117. for k, m in model.named_modules():
  118. if isinstance(m, Conv): # assign export-friendly activations
  119. if isinstance(m.act, nn.Hardswish):
  120. m.act = Hardswish()
  121. elif isinstance(m.act, nn.SiLU):
  122. m.act = SiLU()
  123. elif isinstance(m, Detect):
  124. m.inplace = inplace
  125. m.onnx_dynamic = dynamic
  126. # m.forward = m.forward_export # assign forward (optional)
  127. for _ in range(2):
  128. y = model(img) # dry runs
  129. print(f"\n{colorstr('PyTorch:')} starting from {weights} ({file_size(weights):.1f} MB)")
  130. # Exports
  131. if 'torchscript' in include:
  132. export_torchscript(model, img, file, optimize)
  133. if 'onnx' in include:
  134. export_onnx(model, img, file, opset, train, dynamic, simplify)
  135. if 'coreml' in include:
  136. export_coreml(model, img, file)
  137. # Finish
  138. print(f'\nExport complete ({time.time() - t:.2f}s)'
  139. f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
  140. f'\nVisualize with https://netron.app')
  141. def parse_opt():
  142. parser = argparse.ArgumentParser()
  143. parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
  144. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image (height, width)')
  145. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  146. parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  147. parser.add_argument('--include', nargs='+', default=['torchscript', 'onnx', 'coreml'], help='include formats')
  148. parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
  149. parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
  150. parser.add_argument('--train', action='store_true', help='model.train() mode')
  151. parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
  152. parser.add_argument('--dynamic', action='store_true', help='ONNX: dynamic axes')
  153. parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
  154. parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
  155. opt = parser.parse_args()
  156. return opt
  157. def main(opt):
  158. set_logging()
  159. print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  160. run(**vars(opt))
  161. if __name__ == "__main__":
  162. opt = parse_opt()
  163. main(opt)
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...