Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

hubconf.py 5.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
  1. """YOLOv5 PyTorch Hub models https://pytorch.org/hub/ultralytics_yolov5/
  2. Usage:
  3. import torch
  4. model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
  5. """
  6. import torch
  7. def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  8. """Creates a specified YOLOv5 model
  9. Arguments:
  10. name (str): name of model, i.e. 'yolov5s'
  11. pretrained (bool): load pretrained weights into the model
  12. channels (int): number of input channels
  13. classes (int): number of model classes
  14. autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
  15. verbose (bool): print all information to screen
  16. device (str, torch.device, None): device to use for model parameters
  17. Returns:
  18. YOLOv5 pytorch model
  19. """
  20. from pathlib import Path
  21. from models.yolo import Model, attempt_load
  22. from utils.general import check_requirements, set_logging
  23. from utils.google_utils import attempt_download
  24. from utils.torch_utils import select_device
  25. check_requirements(requirements=Path(__file__).parent / 'requirements.txt',
  26. exclude=('tensorboard', 'thop', 'opencv-python'))
  27. set_logging(verbose=verbose)
  28. fname = Path(name).with_suffix('.pt') # checkpoint filename
  29. try:
  30. if pretrained and channels == 3 and classes == 80:
  31. model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model
  32. else:
  33. cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
  34. model = Model(cfg, channels, classes) # create model
  35. if pretrained:
  36. ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu')) # load
  37. msd = model.state_dict() # model state_dict
  38. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  39. csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
  40. model.load_state_dict(csd, strict=False) # load
  41. if len(ckpt['model'].names) == classes:
  42. model.names = ckpt['model'].names # set class names attribute
  43. if autoshape:
  44. model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
  45. device = select_device('0' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
  46. return model.to(device)
  47. except Exception as e:
  48. help_url = 'https://github.com/ultralytics/yolov5/issues/36'
  49. s = 'Cache may be out of date, try `force_reload=True`. See %s for help.' % help_url
  50. raise Exception(s) from e
  51. def custom(path='path/to/model.pt', autoshape=True, verbose=True, device=None):
  52. # YOLOv5 custom or local model
  53. return _create(path, autoshape=autoshape, verbose=verbose, device=device)
  54. def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  55. # YOLOv5-small model https://github.com/ultralytics/yolov5
  56. return _create('yolov5s', pretrained, channels, classes, autoshape, verbose, device)
  57. def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  58. # YOLOv5-medium model https://github.com/ultralytics/yolov5
  59. return _create('yolov5m', pretrained, channels, classes, autoshape, verbose, device)
  60. def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  61. # YOLOv5-large model https://github.com/ultralytics/yolov5
  62. return _create('yolov5l', pretrained, channels, classes, autoshape, verbose, device)
  63. def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  64. # YOLOv5-xlarge model https://github.com/ultralytics/yolov5
  65. return _create('yolov5x', pretrained, channels, classes, autoshape, verbose, device)
  66. def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  67. # YOLOv5-small-P6 model https://github.com/ultralytics/yolov5
  68. return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose, device)
  69. def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  70. # YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5
  71. return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose, device)
  72. def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  73. # YOLOv5-large-P6 model https://github.com/ultralytics/yolov5
  74. return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose, device)
  75. def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
  76. # YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5
  77. return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose, device)
  78. if __name__ == '__main__':
  79. model = _create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True) # pretrained
  80. # model = custom(path='path/to/model.pt') # custom
  81. # Verify inference
  82. import cv2
  83. import numpy as np
  84. from PIL import Image
  85. imgs = ['data/images/zidane.jpg', # filename
  86. 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg', # URI
  87. cv2.imread('data/images/bus.jpg')[:, :, ::-1], # OpenCV
  88. Image.open('data/images/bus.jpg'), # PIL
  89. np.zeros((320, 640, 3))] # numpy
  90. results = model(imgs) # batched inference
  91. results.print()
  92. results.save()
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...