Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

hubconf.py 4.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
  1. """File for accessing YOLOv5 via PyTorch Hub https://pytorch.org/hub/
  2. Usage:
  3. import torch
  4. model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
  5. """
  6. from pathlib import Path
  7. import torch
  8. from models.yolo import Model
  9. from utils.general import set_logging
  10. from utils.google_utils import attempt_download
  11. dependencies = ['torch', 'yaml']
  12. set_logging()
  13. def create(name, pretrained, channels, classes, autoshape):
  14. """Creates a specified YOLOv5 model
  15. Arguments:
  16. name (str): name of model, i.e. 'yolov5s'
  17. pretrained (bool): load pretrained weights into the model
  18. channels (int): number of input channels
  19. classes (int): number of model classes
  20. Returns:
  21. pytorch model
  22. """
  23. config = Path(__file__).parent / 'models' / f'{name}.yaml' # model.yaml path
  24. try:
  25. model = Model(config, channels, classes)
  26. if pretrained:
  27. fname = f'{name}.pt' # checkpoint filename
  28. attempt_download(fname) # download if not found locally
  29. ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
  30. state_dict = ckpt['model'].float().state_dict() # to FP32
  31. state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
  32. model.load_state_dict(state_dict, strict=False) # load
  33. if len(ckpt['model'].names) == classes:
  34. model.names = ckpt['model'].names # set class names attribute
  35. if autoshape:
  36. model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
  37. return model
  38. except Exception as e:
  39. help_url = 'https://github.com/ultralytics/yolov5/issues/36'
  40. s = 'Cache maybe be out of date, try force_reload=True. See %s for help.' % help_url
  41. raise Exception(s) from e
  42. def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True):
  43. """YOLOv5-small model from https://github.com/ultralytics/yolov5
  44. Arguments:
  45. pretrained (bool): load pretrained weights into the model, default=False
  46. channels (int): number of input channels, default=3
  47. classes (int): number of model classes, default=80
  48. Returns:
  49. pytorch model
  50. """
  51. return create('yolov5s', pretrained, channels, classes, autoshape)
  52. def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True):
  53. """YOLOv5-medium model from https://github.com/ultralytics/yolov5
  54. Arguments:
  55. pretrained (bool): load pretrained weights into the model, default=False
  56. channels (int): number of input channels, default=3
  57. classes (int): number of model classes, default=80
  58. Returns:
  59. pytorch model
  60. """
  61. return create('yolov5m', pretrained, channels, classes, autoshape)
  62. def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True):
  63. """YOLOv5-large model from https://github.com/ultralytics/yolov5
  64. Arguments:
  65. pretrained (bool): load pretrained weights into the model, default=False
  66. channels (int): number of input channels, default=3
  67. classes (int): number of model classes, default=80
  68. Returns:
  69. pytorch model
  70. """
  71. return create('yolov5l', pretrained, channels, classes, autoshape)
  72. def yolov5x(pretrained=False, channels=3, classes=80, autoshape=True):
  73. """YOLOv5-xlarge model from https://github.com/ultralytics/yolov5
  74. Arguments:
  75. pretrained (bool): load pretrained weights into the model, default=False
  76. channels (int): number of input channels, default=3
  77. classes (int): number of model classes, default=80
  78. Returns:
  79. pytorch model
  80. """
  81. return create('yolov5x', pretrained, channels, classes, autoshape)
  82. def custom(path_or_model='path/to/model.pt', autoshape=True):
  83. """YOLOv5-custom model from https://github.com/ultralytics/yolov5
  84. Arguments (3 options):
  85. path_or_model (str): 'path/to/model.pt'
  86. path_or_model (dict): torch.load('path/to/model.pt')
  87. path_or_model (nn.Module): torch.load('path/to/model.pt')['model']
  88. Returns:
  89. pytorch model
  90. """
  91. model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint
  92. if isinstance(model, dict):
  93. model = model['model'] # load model
  94. hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
  95. hub_model.load_state_dict(model.float().state_dict()) # load state_dict
  96. hub_model.names = model.names # class names
  97. return hub_model.autoshape() if autoshape else hub_model
  98. if __name__ == '__main__':
  99. model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True) # pretrained example
  100. # model = custom(path_or_model='path/to/model.pt') # custom example
  101. # Verify inference
  102. from PIL import Image
  103. imgs = [Image.open(x) for x in Path('data/images').glob('*.jpg')]
  104. results = model(imgs)
  105. results.print()
  106. results.save()
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...