1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
|
- """File for accessing YOLOv5 via PyTorch Hub https://pytorch.org/hub/
- Usage:
- import torch
- model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
- """
- from pathlib import Path
- import torch
- from models.yolo import Model
- from utils.general import set_logging
- from utils.google_utils import attempt_download
# torch.hub reads `dependencies` as the list of packages required to load the entrypoints in this file.
dependencies = [
    'torch',
    'yaml',
]
set_logging()  # project helper from utils.general; runs once, at import time
def create(name, pretrained, channels, classes, autoshape):
    """Creates a specified YOLOv5 model

    Arguments:
        name (str): name of model, i.e. 'yolov5s'
        pretrained (bool): load pretrained weights into the model
        channels (int): number of input channels
        classes (int): number of model classes
        autoshape (bool): wrap the model with model.autoshape() (file/URI/PIL/cv2/np inputs and NMS)

    Returns:
        pytorch model

    Raises:
        Exception: wrapping (via `from`) any error raised while building or loading the model,
            with a hint to force-reload the hub cache
    """
    config = Path(__file__).parent / 'models' / f'{name}.yaml'  # model.yaml path
    try:
        model = Model(config, channels, classes)
        if pretrained:
            fname = f'{name}.pt'  # checkpoint filename
            attempt_download(fname)  # download if not found locally
            ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
            ckpt_model = ckpt['model']
            state_dict = ckpt_model.float().state_dict()  # to FP32
            # Build the target state_dict once (not once per key) and keep only tensors whose key exists in the
            # new model with an identical shape (presumably layers sized by `channels`/`classes` differ when those
            # differ from the checkpoint's). Unknown keys are skipped instead of raising KeyError.
            model_sd = model.state_dict()
            state_dict = {k: v for k, v in state_dict.items() if k in model_sd and model_sd[k].shape == v.shape}
            model.load_state_dict(state_dict, strict=False)  # load
            if len(ckpt_model.names) == classes:
                model.names = ckpt_model.names  # set class names attribute
        if autoshape:
            model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
        return model
    except Exception as e:
        help_url = 'https://github.com/ultralytics/yolov5/issues/36'
        s = f'Cache may be out of date, try force_reload=True. See {help_url} for help.'
        raise Exception(s) from e
def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True):
    """YOLOv5-small model from https://github.com/ultralytics/yolov5

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=False
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80
        autoshape (bool): return the model wrapped by model.autoshape(), default=True

    Returns:
        pytorch model
    """
    return create(name='yolov5s', pretrained=pretrained, channels=channels,
                  classes=classes, autoshape=autoshape)
def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True):
    """YOLOv5-medium model from https://github.com/ultralytics/yolov5

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=False
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80
        autoshape (bool): return the model wrapped by model.autoshape(), default=True

    Returns:
        pytorch model
    """
    return create(name='yolov5m', pretrained=pretrained, channels=channels,
                  classes=classes, autoshape=autoshape)
def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True):
    """YOLOv5-large model from https://github.com/ultralytics/yolov5

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=False
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80
        autoshape (bool): return the model wrapped by model.autoshape(), default=True

    Returns:
        pytorch model
    """
    return create(name='yolov5l', pretrained=pretrained, channels=channels,
                  classes=classes, autoshape=autoshape)
def yolov5x(pretrained=False, channels=3, classes=80, autoshape=True):
    """YOLOv5-xlarge model from https://github.com/ultralytics/yolov5

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=False
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80
        autoshape (bool): return the model wrapped by model.autoshape(), default=True

    Returns:
        pytorch model
    """
    return create(name='yolov5x', pretrained=pretrained, channels=channels,
                  classes=classes, autoshape=autoshape)
def custom(path_or_model='path/to/model.pt', autoshape=True):
    """YOLOv5-custom model from https://github.com/ultralytics/yolov5

    Arguments (3 options):
        path_or_model (str or Path): 'path/to/model.pt'
        path_or_model (dict): torch.load('path/to/model.pt')
        path_or_model (nn.Module): torch.load('path/to/model.pt')['model']
        autoshape (bool): return the model wrapped by .autoshape(), default=True

    Returns:
        pytorch model (a fresh Model built from the source model's .yaml, on the source model's device)
    """
    # Accept pathlib.Path as well as str; anything else is treated as an already-loaded checkpoint or model.
    model = torch.load(path_or_model) if isinstance(path_or_model, (str, Path)) else path_or_model  # load checkpoint
    if isinstance(model, dict):
        model = model['model']  # load model
    hub_model = Model(model.yaml).to(next(model.parameters()).device)  # create on the source model's device
    # Cast floating-point entries to FP32 on a copy of the state_dict rather than calling model.float(),
    # which would convert the caller's nn.Module in place when one is passed in.
    state_dict = {k: v.float() if v.is_floating_point() else v for k, v in model.state_dict().items()}
    hub_model.load_state_dict(state_dict)  # load state_dict
    hub_model.names = model.names  # class names
    return hub_model.autoshape() if autoshape else hub_model
if __name__ == '__main__':
    # Pretrained example; swap in the commented custom() line to try a local checkpoint instead.
    model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True)
    # model = custom(path_or_model='path/to/model.pt')  # custom example

    # Verify inference on every .jpg directly under data/images
    from PIL import Image
    images = list(map(Image.open, Path('data/images').glob('*.jpg')))
    detections = model(images)
    detections.print()
    detections.save()
|