Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

export.py 7.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
  1. """Export a YOLOv5 *.pt model to TorchScript, ONNX, CoreML formats
  2. Usage:
  3. $ python path/to/export.py --weights yolov5s.pt --img 640 --batch 1
  4. """
  5. import argparse
  6. import sys
  7. import time
  8. from pathlib import Path
  9. import torch
  10. import torch.nn as nn
  11. from torch.utils.mobile_optimizer import optimize_for_mobile
  12. FILE = Path(__file__).absolute()
  13. sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
  14. from models.common import Conv
  15. from models.yolo import Detect
  16. from models.experimental import attempt_load
  17. from utils.activations import Hardswish, SiLU
  18. from utils.general import colorstr, check_img_size, check_requirements, file_size, set_logging
  19. from utils.torch_utils import select_device
  20. def run(weights='./yolov5s.pt', # weights path
  21. img_size=(640, 640), # image (height, width)
  22. batch_size=1, # batch size
  23. device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
  24. include=('torchscript', 'onnx', 'coreml'), # include formats
  25. half=False, # FP16 half-precision export
  26. inplace=False, # set YOLOv5 Detect() inplace=True
  27. train=False, # model.train() mode
  28. optimize=False, # TorchScript: optimize for mobile
  29. dynamic=False, # ONNX: dynamic axes
  30. simplify=False, # ONNX: simplify model
  31. opset_version=12, # ONNX: opset version
  32. ):
  33. t = time.time()
  34. include = [x.lower() for x in include]
  35. img_size *= 2 if len(img_size) == 1 else 1 # expand
  36. # Load PyTorch model
  37. device = select_device(device)
  38. assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
  39. model = attempt_load(weights, map_location=device) # load FP32 model
  40. labels = model.names
  41. # Input
  42. gs = int(max(model.stride)) # grid size (max stride)
  43. img_size = [check_img_size(x, gs) for x in img_size] # verify img_size are gs-multiples
  44. img = torch.zeros(batch_size, 3, *img_size).to(device) # image size(1,3,320,192) iDetection
  45. # Update model
  46. if half:
  47. img, model = img.half(), model.half() # to FP16
  48. model.train() if train else model.eval() # training mode = no Detect() layer grid construction
  49. for k, m in model.named_modules():
  50. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
  51. if isinstance(m, Conv): # assign export-friendly activations
  52. if isinstance(m.act, nn.Hardswish):
  53. m.act = Hardswish()
  54. elif isinstance(m.act, nn.SiLU):
  55. m.act = SiLU()
  56. elif isinstance(m, Detect):
  57. m.inplace = inplace
  58. m.onnx_dynamic = dynamic
  59. # m.forward = m.forward_export # assign forward (optional)
  60. for _ in range(2):
  61. y = model(img) # dry runs
  62. print(f"\n{colorstr('PyTorch:')} starting from {weights} ({file_size(weights):.1f} MB)")
  63. # TorchScript export -----------------------------------------------------------------------------------------------
  64. if 'torchscript' in include or 'coreml' in include:
  65. prefix = colorstr('TorchScript:')
  66. try:
  67. print(f'\n{prefix} starting export with torch {torch.__version__}...')
  68. f = weights.replace('.pt', '.torchscript.pt') # filename
  69. ts = torch.jit.trace(model, img, strict=False)
  70. (optimize_for_mobile(ts) if optimize else ts).save(f)
  71. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  72. except Exception as e:
  73. print(f'{prefix} export failure: {e}')
  74. # ONNX export ------------------------------------------------------------------------------------------------------
  75. if 'onnx' in include:
  76. prefix = colorstr('ONNX:')
  77. try:
  78. import onnx
  79. print(f'{prefix} starting export with onnx {onnx.__version__}...')
  80. f = weights.replace('.pt', '.onnx') # filename
  81. torch.onnx.export(model, img, f, verbose=False, opset_version=opset_version,
  82. training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
  83. do_constant_folding=not train,
  84. input_names=['images'],
  85. output_names=['output'],
  86. dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
  87. 'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
  88. } if dynamic else None)
  89. # Checks
  90. model_onnx = onnx.load(f) # load onnx model
  91. onnx.checker.check_model(model_onnx) # check onnx model
  92. # print(onnx.helper.printable_graph(model_onnx.graph)) # print
  93. # Simplify
  94. if simplify:
  95. try:
  96. check_requirements(['onnx-simplifier'])
  97. import onnxsim
  98. print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
  99. model_onnx, check = onnxsim.simplify(
  100. model_onnx,
  101. dynamic_input_shape=dynamic,
  102. input_shapes={'images': list(img.shape)} if dynamic else None)
  103. assert check, 'assert check failed'
  104. onnx.save(model_onnx, f)
  105. except Exception as e:
  106. print(f'{prefix} simplifier failure: {e}')
  107. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  108. except Exception as e:
  109. print(f'{prefix} export failure: {e}')
  110. # CoreML export ----------------------------------------------------------------------------------------------------
  111. if 'coreml' in include:
  112. prefix = colorstr('CoreML:')
  113. try:
  114. import coremltools as ct
  115. print(f'{prefix} starting export with coremltools {ct.__version__}...')
  116. assert train, 'CoreML exports should be placed in model.train() mode with `python export.py --train`'
  117. model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
  118. f = weights.replace('.pt', '.mlmodel') # filename
  119. model.save(f)
  120. print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
  121. except Exception as e:
  122. print(f'{prefix} export failure: {e}')
  123. # Finish
  124. print(f'\nExport complete ({time.time() - t:.2f}s). Visualize with https://github.com/lutzroeder/netron.')
  125. def parse_opt():
  126. parser = argparse.ArgumentParser()
  127. parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
  128. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image (height, width)')
  129. parser.add_argument('--batch-size', type=int, default=1, help='batch size')
  130. parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  131. parser.add_argument('--include', nargs='+', default=['torchscript', 'onnx', 'coreml'], help='include formats')
  132. parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
  133. parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
  134. parser.add_argument('--train', action='store_true', help='model.train() mode')
  135. parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
  136. parser.add_argument('--dynamic', action='store_true', help='ONNX: dynamic axes')
  137. parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
  138. parser.add_argument('--opset-version', type=int, default=12, help='ONNX: opset version')
  139. opt = parser.parse_args()
  140. return opt
  141. def main(opt):
  142. set_logging()
  143. print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  144. run(**vars(opt))
  145. if __name__ == "__main__":
  146. opt = parse_opt()
  147. main(opt)
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...