Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

hubconf.py 5.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. PyTorch Hub models https://pytorch.org/hub/ultralytics_yolov5/
  4. Usage:
  5. import torch
  6. model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
  7. """
  8. import torch
  9. def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  10. """Creates a specified YOLOv5 model
  11. Arguments:
  12. name (str): name of model, i.e. 'yolov5s'
  13. pretrained (bool): load pretrained weights into the model
  14. channels (int): number of input channels
  15. classes (int): number of model classes
  16. autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
  17. verbose (bool): print all information to screen
  18. device (str, torch.device, None): device to use for model parameters
  19. Returns:
  20. YOLOv5 pytorch model
  21. """
  22. from pathlib import Path
  23. from models.yolo import Model
  24. from models.experimental import attempt_load
  25. from utils.general import check_requirements, set_logging
  26. from utils.downloads import attempt_download
  27. from utils.torch_utils import select_device
  28. file = Path(__file__).resolve()
  29. check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
  30. set_logging(verbose=verbose)
  31. save_dir = Path('') if str(name).endswith('.pt') else file.parent
  32. path = (save_dir / name).with_suffix('.pt') # checkpoint path
  33. try:
  34. device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
  35. if pretrained and channels == 3 and classes == 80:
  36. model = attempt_load(path, map_location=device) # download/load FP32 model
  37. else:
  38. cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
  39. model = Model(cfg, channels, classes) # create model
  40. if pretrained:
  41. ckpt = torch.load(attempt_download(path), map_location=device) # load
  42. msd = model.state_dict() # model state_dict
  43. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  44. csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
  45. model.load_state_dict(csd, strict=False) # load
  46. if len(ckpt['model'].names) == classes:
  47. model.names = ckpt['model'].names # set class names attribute
  48. if autoshape:
  49. model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
  50. return model.to(device)
  51. except Exception as e:
  52. help_url = 'https://github.com/ultralytics/yolov5/issues/36'
  53. s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
  54. raise Exception(s) from e
  55. def custom(path='path/to/model.pt', autoshape=True, verbose=True, device=None):
  56. # YOLOv5 custom or local model
  57. return _create(path, autoshape=autoshape, verbose=verbose, device=device)
  58. def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  59. # YOLOv5-small model https://github.com/ultralytics/yolov5
  60. return _create('yolov5s', pretrained, channels, classes, autoshape, verbose, device)
  61. def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  62. # YOLOv5-medium model https://github.com/ultralytics/yolov5
  63. return _create('yolov5m', pretrained, channels, classes, autoshape, verbose, device)
  64. def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  65. # YOLOv5-large model https://github.com/ultralytics/yolov5
  66. return _create('yolov5l', pretrained, channels, classes, autoshape, verbose, device)
  67. def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  68. # YOLOv5-xlarge model https://github.com/ultralytics/yolov5
  69. return _create('yolov5x', pretrained, channels, classes, autoshape, verbose, device)
  70. def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  71. # YOLOv5-small-P6 model https://github.com/ultralytics/yolov5
  72. return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose, device)
  73. def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  74. # YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5
  75. return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose, device)
  76. def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  77. # YOLOv5-large-P6 model https://github.com/ultralytics/yolov5
  78. return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose, device)
  79. def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  80. # YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5
  81. return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose, device)
  82. if __name__ == '__main__':
  83. model = _create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True) # pretrained
  84. # model = custom(path='path/to/model.pt') # custom
  85. # Verify inference
  86. import cv2
  87. import numpy as np
  88. from PIL import Image
  89. from pathlib import Path
  90. imgs = ['data/images/zidane.jpg', # filename
  91. Path('data/images/zidane.jpg'), # Path
  92. 'https://ultralytics.com/images/zidane.jpg', # URI
  93. cv2.imread('data/images/bus.jpg')[:, :, ::-1], # OpenCV
  94. Image.open('data/images/bus.jpg'), # PIL
  95. np.zeros((320, 640, 3))] # numpy
  96. results = model(imgs) # batched inference
  97. results.print()
  98. results.save()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...