Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

hubconf.py 4.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
  1. """File for accessing YOLOv5 via PyTorch Hub https://pytorch.org/hub/
  2. Usage:
  3. import torch
  4. model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
  5. """
  6. from pathlib import Path
  7. import torch
  8. from models.yolo import Model
  9. from utils.general import set_logging
  10. from utils.google_utils import attempt_download
  11. dependencies = ['torch', 'yaml']
  12. set_logging()
  13. def create(name, pretrained, channels, classes):
  14. """Creates a specified YOLOv5 model
  15. Arguments:
  16. name (str): name of model, i.e. 'yolov5s'
  17. pretrained (bool): load pretrained weights into the model
  18. channels (int): number of input channels
  19. classes (int): number of model classes
  20. Returns:
  21. pytorch model
  22. """
  23. config = Path(__file__).parent / 'models' / f'{name}.yaml' # model.yaml path
  24. try:
  25. model = Model(config, channels, classes)
  26. if pretrained:
  27. fname = f'{name}.pt' # checkpoint filename
  28. attempt_download(fname) # download if not found locally
  29. ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
  30. state_dict = ckpt['model'].float().state_dict() # to FP32
  31. state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
  32. model.load_state_dict(state_dict, strict=False) # load
  33. if len(ckpt['model'].names) == classes:
  34. model.names = ckpt['model'].names # set class names attribute
  35. # model = model.autoshape() # for PIL/cv2/np inputs and NMS
  36. return model
  37. except Exception as e:
  38. help_url = 'https://github.com/ultralytics/yolov5/issues/36'
  39. s = 'Cache maybe be out of date, try force_reload=True. See %s for help.' % help_url
  40. raise Exception(s) from e
  41. def yolov5s(pretrained=False, channels=3, classes=80):
  42. """YOLOv5-small model from https://github.com/ultralytics/yolov5
  43. Arguments:
  44. pretrained (bool): load pretrained weights into the model, default=False
  45. channels (int): number of input channels, default=3
  46. classes (int): number of model classes, default=80
  47. Returns:
  48. pytorch model
  49. """
  50. return create('yolov5s', pretrained, channels, classes)
  51. def yolov5m(pretrained=False, channels=3, classes=80):
  52. """YOLOv5-medium model from https://github.com/ultralytics/yolov5
  53. Arguments:
  54. pretrained (bool): load pretrained weights into the model, default=False
  55. channels (int): number of input channels, default=3
  56. classes (int): number of model classes, default=80
  57. Returns:
  58. pytorch model
  59. """
  60. return create('yolov5m', pretrained, channels, classes)
  61. def yolov5l(pretrained=False, channels=3, classes=80):
  62. """YOLOv5-large model from https://github.com/ultralytics/yolov5
  63. Arguments:
  64. pretrained (bool): load pretrained weights into the model, default=False
  65. channels (int): number of input channels, default=3
  66. classes (int): number of model classes, default=80
  67. Returns:
  68. pytorch model
  69. """
  70. return create('yolov5l', pretrained, channels, classes)
  71. def yolov5x(pretrained=False, channels=3, classes=80):
  72. """YOLOv5-xlarge model from https://github.com/ultralytics/yolov5
  73. Arguments:
  74. pretrained (bool): load pretrained weights into the model, default=False
  75. channels (int): number of input channels, default=3
  76. classes (int): number of model classes, default=80
  77. Returns:
  78. pytorch model
  79. """
  80. return create('yolov5x', pretrained, channels, classes)
  81. def custom(path_or_model='path/to/model.pt'):
  82. """YOLOv5-custom model from https://github.com/ultralytics/yolov5
  83. Arguments (3 options):
  84. path_or_model (str): 'path/to/model.pt'
  85. path_or_model (dict): torch.load('path/to/model.pt')
  86. path_or_model (nn.Module): torch.load('path/to/model.pt')['model']
  87. Returns:
  88. pytorch model
  89. """
  90. model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint
  91. if isinstance(model, dict):
  92. model = model['model'] # load model
  93. hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
  94. hub_model.load_state_dict(model.float().state_dict()) # load state_dict
  95. hub_model.names = model.names # class names
  96. return hub_model
  97. if __name__ == '__main__':
  98. model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # pretrained example
  99. # model = custom(path_or_model='path/to/model.pt') # custom example
  100. model = model.autoshape() # for PIL/cv2/np inputs and NMS
  101. # Verify inference
  102. from PIL import Image
  103. imgs = [Image.open(x) for x in Path('data/images').glob('*.jpg')]
  104. results = model(imgs)
  105. results.show()
  106. results.print()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...