Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

detect.py 8.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
  1. import argparse
  2. import time
  3. from pathlib import Path
  4. import cv2
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. from numpy import random
  8. from models.experimental import attempt_load
  9. from utils.datasets import LoadStreams, LoadImages
  10. from utils.general import check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, \
  11. strip_optimizer, set_logging, increment_path
  12. from utils.plots import plot_one_box
  13. from utils.torch_utils import select_device, load_classifier, time_synchronized
  14. def detect(save_img=False):
  15. source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
  16. webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
  17. ('rtsp://', 'rtmp://', 'http://'))
  18. # Directories
  19. save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
  20. (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
  21. # Initialize
  22. set_logging()
  23. device = select_device(opt.device)
  24. half = device.type != 'cpu' # half precision only supported on CUDA
  25. # Load model
  26. model = attempt_load(weights, map_location=device) # load FP32 model
  27. imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
  28. if half:
  29. model.half() # to FP16
  30. # Second-stage classifier
  31. classify = False
  32. if classify:
  33. modelc = load_classifier(name='resnet101', n=2) # initialize
  34. modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()
  35. # Set Dataloader
  36. vid_path, vid_writer = None, None
  37. if webcam:
  38. view_img = True
  39. cudnn.benchmark = True # set True to speed up constant image size inference
  40. dataset = LoadStreams(source, img_size=imgsz)
  41. else:
  42. save_img = True
  43. dataset = LoadImages(source, img_size=imgsz)
  44. # Get names and colors
  45. names = model.module.names if hasattr(model, 'module') else model.names
  46. colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
  47. # Run inference
  48. t0 = time.time()
  49. img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
  50. _ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
  51. for path, img, im0s, vid_cap in dataset:
  52. img = torch.from_numpy(img).to(device)
  53. img = img.half() if half else img.float() # uint8 to fp16/32
  54. img /= 255.0 # 0 - 255 to 0.0 - 1.0
  55. if img.ndimension() == 3:
  56. img = img.unsqueeze(0)
  57. # Inference
  58. t1 = time_synchronized()
  59. pred = model(img, augment=opt.augment)[0]
  60. # Apply NMS
  61. pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
  62. t2 = time_synchronized()
  63. # Apply Classifier
  64. if classify:
  65. pred = apply_classifier(pred, modelc, img, im0s)
  66. # Process detections
  67. for i, det in enumerate(pred): # detections per image
  68. if webcam: # batch_size >= 1
  69. p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
  70. else:
  71. p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
  72. p = Path(p) # to Path
  73. save_path = str(save_dir / p.name) # img.jpg
  74. txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt
  75. s += '%gx%g ' % img.shape[2:] # print string
  76. gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
  77. if len(det):
  78. # Rescale boxes from img_size to im0 size
  79. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
  80. # Print results
  81. for c in det[:, -1].unique():
  82. n = (det[:, -1] == c).sum() # detections per class
  83. s += f'{n} {names[int(c)]}s, ' # add to string
  84. # Write results
  85. for *xyxy, conf, cls in reversed(det):
  86. if save_txt: # Write to file
  87. xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
  88. line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh) # label format
  89. with open(txt_path + '.txt', 'a') as f:
  90. f.write(('%g ' * len(line)).rstrip() % line + '\n')
  91. if save_img or view_img: # Add bbox to image
  92. label = f'{names[int(cls)]} {conf:.2f}'
  93. plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
  94. # Print time (inference + NMS)
  95. print(f'{s}Done. ({t2 - t1:.3f}s)')
  96. # Stream results
  97. if view_img:
  98. cv2.imshow(str(p), im0)
  99. if cv2.waitKey(1) == ord('q'): # q to quit
  100. raise StopIteration
  101. # Save results (image with detections)
  102. if save_img:
  103. if dataset.mode == 'image':
  104. cv2.imwrite(save_path, im0)
  105. else: # 'video'
  106. if vid_path != save_path: # new video
  107. vid_path = save_path
  108. if isinstance(vid_writer, cv2.VideoWriter):
  109. vid_writer.release() # release previous video writer
  110. fourcc = 'mp4v' # output video codec
  111. fps = vid_cap.get(cv2.CAP_PROP_FPS)
  112. w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  113. h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  114. vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
  115. vid_writer.write(im0)
  116. if save_txt or save_img:
  117. s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
  118. print(f"Results saved to {save_dir}{s}")
  119. print(f'Done. ({time.time() - t0:.3f}s)')
  120. if __name__ == '__main__':
  121. parser = argparse.ArgumentParser()
  122. parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
  123. parser.add_argument('--source', type=str, default='data/images', help='source') # file/folder, 0 for webcam
  124. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
  125. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  126. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  127. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  128. parser.add_argument('--view-img', action='store_true', help='display results')
  129. parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
  130. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
  131. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  132. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  133. parser.add_argument('--augment', action='store_true', help='augmented inference')
  134. parser.add_argument('--update', action='store_true', help='update all models')
  135. parser.add_argument('--project', default='runs/detect', help='save results to project/name')
  136. parser.add_argument('--name', default='exp', help='save results to project/name')
  137. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  138. opt = parser.parse_args()
  139. print(opt)
  140. with torch.no_grad():
  141. if opt.update: # update all models (to fix SourceChangeWarning)
  142. for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
  143. detect()
  144. strip_optimizer(opt.weights)
  145. else:
  146. detect()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...