pipelines.py
import copy
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union, Iterable
from contextlib import contextmanager

from tqdm import tqdm
import numpy as np
import torch

from super_gradients.module_interfaces import SupportsInputShapeCheck
from super_gradients.training.utils.predict import (
    ImagePoseEstimationPrediction,
    ImagesPoseEstimationPrediction,
    VideoPoseEstimationPrediction,
    ImagesDetectionPrediction,
    VideoDetectionPrediction,
    ImagePrediction,
    ImageDetectionPrediction,
    ImagesPredictions,
    VideoPredictions,
    Prediction,
    DetectionPrediction,
    PoseEstimationPrediction,
    ImageClassificationPrediction,
    ImagesClassificationPrediction,
    ClassificationPrediction,
    ImageSegmentationPrediction,
    ImagesSegmentationPrediction,
    SegmentationPrediction,
    VideoSegmentationPrediction,
)
from super_gradients.training.utils.utils import generate_batch, infer_model_device, resolve_torch_device
from super_gradients.training.utils.media.video import includes_video_extension, lazy_load_video
from super_gradients.training.utils.media.image import ImageSource, check_image_typing
from super_gradients.training.utils.media.stream import WebcamStreaming
from super_gradients.training.utils.detection_utils import DetectionPostPredictionCallback
from super_gradients.training.models.sg_module import SgModule
from super_gradients.training.processing.processing import Processing, ComposeProcessing, ImagePermute
from super_gradients.common.abstractions.abstract_logger import get_logger

logger = get_logger(__name__)

@contextmanager
def eval_mode(model: SgModule) -> None:
    """Set a model in evaluation mode, undo at the end.

    :param model: The model to set in evaluation mode.
    """
    _starting_mode = model.training
    model.eval()
    yield
    model.train(mode=_starting_mode)
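
# Illustrative usage of `eval_mode` (a sketch, not part of the original module):
# the context manager flips the model to eval for the duration of the block and
# then restores whatever training mode the caller had set, e.g.:
#
#     model.train()
#     with eval_mode(model):
#         output = model(batch)  # model.training is False inside the block
#     assert model.training      # original mode restored on exit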

class Pipeline(ABC):
    """An abstract base class representing a processing pipeline for a specific task.
    The pipeline includes loading images, preprocessing, prediction, and postprocessing.

    :param model:           The model used for making predictions.
    :param image_processor: A single image processor or a list of image processors for preprocessing and postprocessing the images.
    :param class_names:     List of class names corresponding to the model's output classes.
    :param device:          The device on which the model will be run. If None, will run on current model device. Use "cuda" for GPU support.
    :param dtype:           Specify the dtype of the inputs. If None, will use the dtype of the model's parameters.
    :param fuse_model:      If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
    :param fp16:            If True, use mixed precision for inference.
    """

    def __init__(
        self,
        model: SgModule,
        image_processor: Union[Processing, List[Processing]],
        class_names: List[str],
        device: Optional[str] = None,
        fuse_model: bool = True,
        dtype: Optional[torch.dtype] = None,
        fp16: bool = True,
    ):
        model_device: torch.device = infer_model_device(model=model)
        if device:
            device: torch.device = resolve_torch_device(device=device)

        self.device: torch.device = device or model_device
        self.dtype = dtype or next(model.parameters()).dtype
        self.model = model.to(device) if device and device != model_device else model
        self.class_names = class_names

        if isinstance(image_processor, list):
            image_processor = ComposeProcessing(image_processor)
        self.image_processor = image_processor

        self.fuse_model = fuse_model  # If True, the model will be fused in the first forward pass, to make sure it gets the right input_size
        self.fp16 = fp16

    def _fuse_model(self, input_example: torch.Tensor):
        logger.info("Fusing some of the model's layers. If this takes too much memory, you can deactivate it by setting `fuse_model=False`")
        self.model = copy.deepcopy(self.model)
        self.model.eval()
        self.model.prep_model_for_conversion(input_size=input_example.shape[-2:])
        self.fuse_model = False

    def __call__(self, inputs: Union[str, ImageSource, List[ImageSource]], batch_size: Optional[int] = 32) -> ImagesPredictions:
        """Predict an image or a list of images.

        Supported types include:
            - str:              A string representing either a video, an image, or a URL.
            - numpy.ndarray:    A numpy array representing the image.
            - torch.Tensor:     A PyTorch tensor representing the image.
            - PIL.Image.Image:  A PIL Image object.
            - List:             A list of images of any of the above image types (a list of videos is not supported).

        :param inputs:      Inputs to the model, which can be any of the above-mentioned types.
        :param batch_size:  Maximum number of images to process at the same time.
        :return:            Results of the prediction.
        """
        if includes_video_extension(inputs):
            return self.predict_video(inputs, batch_size)
        elif check_image_typing(inputs):
            return self.predict_images(inputs, batch_size)
        else:
            raise ValueError(f"Input {inputs} not supported for prediction.")

    def predict_images(self, images: Union[ImageSource, List[ImageSource]], batch_size: Optional[int] = 32) -> Union[ImagesPredictions, ImagePrediction]:
        """Predict an image or a list of images.

        :param images:      Images to predict.
        :param batch_size:  The size of each batch.
        :return:            Results of the prediction.
        """
        from super_gradients.training.utils.media.image import load_images

        images = load_images(images)
        result_generator = self._generate_prediction_result(images=images, batch_size=batch_size)
        return self._combine_image_prediction_to_images(result_generator, n_images=len(images))

    def predict_video(self, video_path: str, batch_size: Optional[int] = 32) -> VideoPredictions:
        """Predict on a video file by processing its frames in batches.

        :param video_path:  Path to the video file.
        :param batch_size:  The size of each batch.
        :return:            Results of the prediction.
        """
        video_frames, fps, num_frames = lazy_load_video(file_path=video_path)
        result_generator = self._generate_prediction_result(images=video_frames, batch_size=batch_size)
        return self._combine_image_prediction_to_video(result_generator, fps=fps, n_images=num_frames)

    def predict_webcam(self) -> None:
        """Predict using webcam."""

        def _draw_predictions(frame: np.ndarray) -> np.ndarray:
            """Draw the predictions on a single frame from the stream."""
            frame_prediction = next(iter(self._generate_prediction_result(images=[frame])))
            return frame_prediction.draw()

        video_streaming = WebcamStreaming(frame_processing_fn=_draw_predictions, fps_update_frequency=1)
        video_streaming.run()

    def _generate_prediction_result(self, images: Iterable[np.ndarray], batch_size: Optional[int] = None) -> Iterable[ImagePrediction]:
        """Run the pipeline on the images, either as a single batch or through multiple batches.

        NOTE: A core motivation for implementing this function as a generator is that it can be used lazily (if `images` is itself a generator),
        i.e. without having to load all the images into memory.

        :param images:      Iterable of numpy arrays representing images.
        :param batch_size:  The size of each batch.
        :return:            Iterable of Results objects, each containing the results of the prediction and the image.
        """
        if batch_size is None:
            yield from self._generate_prediction_result_single_batch(images)
        else:
            for batch_images in generate_batch(images, batch_size):
                yield from self._generate_prediction_result_single_batch(batch_images)

    def _generate_prediction_result_single_batch(self, images: Iterable[np.ndarray]) -> Iterable[ImagePrediction]:
        """Run the pipeline on images. The pipeline is made of 4 steps:
            1. Load images - Loading the images into a list of numpy arrays.
            2. Preprocess - Encode the image in the shape/format expected by the model.
            3. Predict - Run the model on the preprocessed image.
            4. Postprocess - Decode the output of the model so that the predictions are in the shape/format of the original image.

        :param images:  Iterable of numpy arrays representing images.
        :return:        Iterable of Results objects, each containing the results of the prediction and the image.
        """
        # Make sure the model is on the correct device, as it might have been moved after init
        model_device: torch.device = infer_model_device(model=self.model)
        if self.device != model_device:
            self.model = self.model.to(self.device)

        images = list(images)  # We need to load all the images into memory in order to reuse them afterwards.

        # Preprocess
        preprocessed_images, processing_metadatas = [], []
        for image in images:
            preprocessed_image, processing_metadata = self.image_processor.preprocess_image(image=image.copy())
            preprocessed_images.append(preprocessed_image)
            processing_metadatas.append(processing_metadata)

        reference_shape = preprocessed_images[0].shape
        for img in preprocessed_images:
            if img.shape != reference_shape:
                raise ValueError(
                    f"Images have different shapes ({img.shape} != {reference_shape})!\n"
                    f"Either resize the images to the same size, set `skip_image_resizing=False`, or pass one image at a time."
                )

        # Predict
        predictions = self.pass_images_through_model(preprocessed_images)

        # Postprocess
        postprocessed_predictions = []
        for image, prediction, processing_metadata in zip(images, predictions, processing_metadatas):
            prediction = self.image_processor.postprocess_predictions(predictions=prediction, metadata=processing_metadata)
            postprocessed_predictions.append(prediction)

        # Yield results one by one
        for image, prediction in zip(images, postprocessed_predictions):
            yield self._instantiate_image_prediction(image=image, prediction=prediction)

    def pass_images_through_model(self, preprocessed_images: List[np.ndarray]) -> List[Prediction]:
        with eval_mode(self.model), torch.no_grad(), torch.cuda.amp.autocast(enabled=self.fp16):
            torch_inputs = self._prep_inputs_for_model(preprocessed_images)
            model_output = self.model(torch_inputs)
            predictions = self._decode_model_output(model_output, model_input=torch_inputs)
        return predictions

    def _prep_inputs_for_model(self, preprocessed_images: List[np.ndarray]) -> torch.Tensor:
        torch_inputs = torch.from_numpy(np.array(preprocessed_images)).to(self.device)
        torch_inputs = torch_inputs.to(self.dtype)
        if isinstance(self.model, SupportsInputShapeCheck):
            self.model.validate_input_shape(torch_inputs.size())
        if self.fuse_model:
            self._fuse_model(torch_inputs)
        return torch_inputs
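
    # Shape walkthrough (illustrative sketch, not part of the original module):
    # the preprocessed images are expected to be same-shaped channel-first numpy
    # arrays (ImagePermute ensures CHW for detection), so np.array(...) stacks
    # them into a single [B, C, H, W] batch before the tensor is moved to the
    # pipeline's device and dtype:
    #
    #     imgs = [np.zeros((3, 640, 640), dtype=np.float32) for _ in range(4)]
    #     torch.from_numpy(np.array(imgs)).shape  # -> torch.Size([4, 3, 640, 640])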

    @abstractmethod
    def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[Prediction]:
        """Decode the model outputs, move each prediction to numpy and store it in a Prediction object.

        :param model_output:    Direct output of the model, without any post-processing.
        :param model_input:     Model input (i.e. images after preprocessing).
        :return:                Model predictions, without any post-processing.
        """
        raise NotImplementedError

    @abstractmethod
    def _instantiate_image_prediction(self, image: np.ndarray, prediction: Prediction) -> ImagePrediction:
        """Instantiate an object wrapping an image and the pipeline's prediction.

        :param image:       Image to predict.
        :param prediction:  Model prediction on that image.
        :return:            Object wrapping an image and the pipeline's prediction.
        """
        raise NotImplementedError

    @abstractmethod
    def _combine_image_prediction_to_images(
        self, images_prediction_lst: Iterable[ImagePrediction], n_images: Optional[int] = None
    ) -> Union[ImagesPredictions, ImagePrediction]:
        """Instantiate an object wrapping the list of images (or an ImagePrediction for a single prediction)
        and the pipeline's predictions on them.

        :param images_prediction_lst:   List of image predictions.
        :param n_images:                (Optional) Number of images in the list. This is used for the tqdm progress bar to work with iterables, but is not required.
        :return:                        Object wrapping the list of image predictions.
        """
        raise NotImplementedError

    @abstractmethod
    def _combine_image_prediction_to_video(
        self, images_prediction_lst: Iterable[ImagePrediction], fps: float, n_images: Optional[int] = None
    ) -> VideoPredictions:
        """Instantiate an object holding the video frames and the pipeline's predictions on them.

        :param images_prediction_lst:   List of image predictions.
        :param fps:                     Frames per second.
        :param n_images:                (Optional) Number of images in the list. This is used for the tqdm progress bar to work with iterables, but is not required.
        :return:                        Object wrapping the list of image predictions as a video.
        """
        raise NotImplementedError

class DetectionPipeline(Pipeline):
    """Pipeline specifically designed for object detection tasks.
    The pipeline includes loading images, preprocessing, prediction, and postprocessing.

    :param model:                       The object detection model (instance of SgModule) used for making predictions.
    :param class_names:                 List of class names corresponding to the model's output classes.
    :param post_prediction_callback:    Callback function to process raw predictions from the model.
    :param image_processor:             Single image processor or a list of image processors for preprocessing and postprocessing the images.
    :param device:                      The device on which the model will be run. If None, will run on current model device. Use "cuda" for GPU support.
    :param fuse_model:                  If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
    :param fp16:                        If True, use mixed precision for inference.
    """

    def __init__(
        self,
        model: SgModule,
        class_names: List[str],
        post_prediction_callback: DetectionPostPredictionCallback,
        device: Optional[str] = None,
        image_processor: Union[Processing, List[Processing]] = None,
        fuse_model: bool = True,
        fp16: bool = True,
    ):
        if isinstance(image_processor, list):
            image_processor = ComposeProcessing(image_processor)

        has_image_permute = any(isinstance(image_processing, ImagePermute) for image_processing in image_processor.processings)
        if not has_image_permute:
            image_processor.processings.append(ImagePermute())

        super().__init__(
            model=model,
            device=device,
            image_processor=image_processor,
            class_names=class_names,
            fuse_model=fuse_model,
            fp16=fp16,
        )
        self.post_prediction_callback = post_prediction_callback

    def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[DetectionPrediction]:
        """Decode the model output by applying the post prediction callback. This includes NMS.

        :param model_output:    Direct output of the model, without any post-processing.
        :param model_input:     Model input (i.e. images after preprocessing).
        :return:                Predicted bboxes.
        """
        post_nms_predictions = self.post_prediction_callback(model_output, device=self.device)
        return self._decode_detection_model_output(model_input, post_nms_predictions)

    @staticmethod
    def _decode_detection_model_output(model_input: np.ndarray, post_nms_predictions: List[torch.Tensor]) -> List[DetectionPrediction]:
        predictions = []
        for prediction, image in zip(post_nms_predictions, model_input):
            prediction = prediction if prediction is not None else torch.zeros((0, 6), dtype=torch.float32)
            prediction = prediction.detach().cpu().numpy()
            predictions.append(
                DetectionPrediction(
                    bboxes=prediction[:, :4],
                    confidence=prediction[:, 4],
                    labels=prediction[:, 5].astype(int),
                    bbox_format="xyxy",
                    image_shape=image.shape,
                )
            )
        return predictions

    def _instantiate_image_prediction(self, image: np.ndarray, prediction: DetectionPrediction) -> ImagePrediction:
        return ImageDetectionPrediction(image=image, prediction=prediction, class_names=self.class_names)

    def _combine_image_prediction_to_images(
        self, images_predictions: Iterable[ImageDetectionPrediction], n_images: Optional[int] = None
    ) -> Union[ImagesDetectionPrediction, ImageDetectionPrediction]:
        if n_images is not None and n_images == 1:
            # Do not show tqdm progress bar if there is only one image
            images_predictions = next(iter(images_predictions))
        else:
            images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Images")]
            images_predictions = ImagesDetectionPrediction(_images_prediction_lst=images_predictions)
        return images_predictions

    def _combine_image_prediction_to_video(
        self, images_predictions: Iterable[ImageDetectionPrediction], fps: float, n_images: Optional[int] = None
    ) -> VideoDetectionPrediction:
        return VideoDetectionPrediction(_images_prediction_gen=images_predictions, fps=fps, n_frames=n_images)
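
# Minimal usage sketch (illustrative; `my_detector`, the label set, the callback
# and the processing below are hypothetical stand-ins, not defined in this module).
# In practice these objects come from the model's training recipe:
#
#     pipeline = DetectionPipeline(
#         model=my_detector,                         # any detection SgModule
#         class_names=["person", "car"],             # hypothetical label set
#         post_prediction_callback=my_nms_callback,  # a DetectionPostPredictionCallback
#         image_processor=my_processing,             # Processing or list of Processing
#     )
#     predictions = pipeline("path/to/image.jpg", batch_size=8)  # image(s) or a video path
#     predictions.show()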

class SlidingWindowDetectionPipeline(DetectionPipeline):
    def pass_images_through_model(self, preprocessed_images: List[np.ndarray]) -> List[Prediction]:
        with eval_mode(self.model), torch.no_grad(), torch.cuda.amp.autocast(enabled=self.fp16):
            torch_inputs = self._prep_inputs_for_model(preprocessed_images)
            model_output = self.model(torch_inputs, sliding_window_post_prediction_callback=self.post_prediction_callback)
            predictions = self._decode_model_output(model_output, model_input=torch_inputs)
        return predictions

    def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[DetectionPrediction]:
        """Decode the model output. For this pipeline the post prediction callback (including NMS) is already applied
        inside the forward pass, so the output only needs to be unpacked into DetectionPrediction objects.

        :param model_output:    Output of the model after the sliding-window post prediction callback.
        :param model_input:     Model input (i.e. images after preprocessing).
        :return:                Predicted bboxes.
        """
        return self._decode_detection_model_output(model_input, model_output)

    def _fuse_model(self, input_example: torch.Tensor):
        logger.info("Fusing some of the model's layers. If this takes too much memory, you can deactivate it by setting `fuse_model=False`")
        self.model = copy.deepcopy(self.model)
        self.model.eval()
        self.model.model.prep_model_for_conversion(input_size=input_example.shape[-2:])
        self.fuse_model = False

class PoseEstimationPipeline(Pipeline):
    """Pipeline specifically designed for pose estimation tasks.
    The pipeline includes loading images, preprocessing, prediction, and postprocessing.

    :param model:                       The pose estimation model (instance of SgModule) used for making predictions.
    :param edge_links:                  Pairs of keypoint indices forming the edges of the skeleton.
    :param edge_colors:                 (R, G, B) color for each edge.
    :param keypoint_colors:             (R, G, B) color for each keypoint.
    :param post_prediction_callback:    Callback function to process raw predictions from the model.
    :param image_processor:             Single image processor or a list of image processors for preprocessing and postprocessing the images.
    :param device:                      The device on which the model will be run. If None, will run on current model device. Use "cuda" for GPU support.
    :param fuse_model:                  If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
    :param fp16:                        If True, use mixed precision for inference.
    """

    def __init__(
        self,
        model: SgModule,
        edge_links: Union[np.ndarray, List[Tuple[int, int]]],
        edge_colors: Union[np.ndarray, List[Tuple[int, int, int]]],
        keypoint_colors: Union[np.ndarray, List[Tuple[int, int, int]]],
        post_prediction_callback,
        device: Optional[str] = None,
        image_processor: Union[Processing, List[Processing]] = None,
        fuse_model: bool = True,
        fp16: bool = True,
    ):
        if isinstance(image_processor, list):
            image_processor = ComposeProcessing(image_processor)

        super().__init__(
            model=model,
            device=device,
            image_processor=image_processor,
            class_names=None,
            fuse_model=fuse_model,
            fp16=fp16,
        )
        self.post_prediction_callback = post_prediction_callback
        self.edge_links = np.asarray(edge_links, dtype=int)
        self.edge_colors = np.asarray(edge_colors, dtype=int)
        self.keypoint_colors = np.asarray(keypoint_colors, dtype=int)

    def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[PoseEstimationPrediction]:
        """Decode the model output by applying the post prediction callback. This includes NMS.

        :param model_output:    Direct output of the model, without any post-processing.
        :param model_input:     Model input (i.e. images after preprocessing).
        :return:                Predicted poses.
        """
        list_of_predictions = self.post_prediction_callback(model_output)
        decoded_predictions = []
        for image_level_predictions, image in zip(list_of_predictions, model_input):
            decoded_predictions.append(
                PoseEstimationPrediction(
                    poses=image_level_predictions.poses.cpu().numpy() if torch.is_tensor(image_level_predictions.poses) else image_level_predictions.poses,
                    scores=image_level_predictions.scores.cpu().numpy() if torch.is_tensor(image_level_predictions.scores) else image_level_predictions.scores,
                    bboxes_xyxy=(
                        image_level_predictions.bboxes_xyxy.cpu().numpy()
                        if torch.is_tensor(image_level_predictions.bboxes_xyxy)
                        else image_level_predictions.bboxes_xyxy
                    ),
                    image_shape=image.shape,
                    edge_links=self.edge_links,
                    edge_colors=self.edge_colors,
                    keypoint_colors=self.keypoint_colors,
                )
            )
        return decoded_predictions

    def _instantiate_image_prediction(self, image: np.ndarray, prediction: PoseEstimationPrediction) -> ImagePrediction:
        return ImagePoseEstimationPrediction(image=image, prediction=prediction, class_names=self.class_names)

    def _combine_image_prediction_to_images(
        self, images_predictions: Iterable[PoseEstimationPrediction], n_images: Optional[int] = None
    ) -> Union[ImagesPoseEstimationPrediction, ImagePoseEstimationPrediction]:
        if n_images is not None and n_images == 1:
            # Do not show tqdm progress bar if there is only one image
            images_predictions = next(iter(images_predictions))
        else:
            images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Images")]
            images_predictions = ImagesPoseEstimationPrediction(_images_prediction_lst=images_predictions)
        return images_predictions

    def _combine_image_prediction_to_video(
        self, images_predictions: Iterable[ImagePoseEstimationPrediction], fps: float, n_images: Optional[int] = None
    ) -> VideoPoseEstimationPrediction:
        return VideoPoseEstimationPrediction(_images_prediction_gen=images_predictions, fps=fps, n_frames=n_images)

class ClassificationPipeline(Pipeline):
    """Pipeline specifically designed for image classification tasks.
    The pipeline includes loading images, preprocessing, prediction, and postprocessing.

    :param model:           The classification model (instance of SgModule) used for making predictions.
    :param class_names:     List of class names corresponding to the model's output classes.
    :param image_processor: Single image processor or a list of image processors for preprocessing and postprocessing the images.
    :param device:          The device on which the model will be run. If None, will run on current model device. Use "cuda" for GPU support.
    :param fuse_model:      If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
    :param fp16:            If True, use mixed precision for inference.
    """

    def __init__(
        self,
        model: SgModule,
        class_names: List[str],
        device: Optional[str] = None,
        image_processor: Union[Processing, List[Processing]] = None,
        fuse_model: bool = True,
        fp16: bool = True,
    ):
        super().__init__(
            model=model,
            device=device,
            image_processor=image_processor,
            class_names=class_names,
            fuse_model=fuse_model,
            fp16=fp16,
        )

    def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[ClassificationPrediction]:
        """Decode the model output.

        :param model_output:    Direct output of the model, without any post-processing. Tensor of shape [B, C].
        :param model_input:     Model input (i.e. images after preprocessing).
        :return:                Predicted labels with their confidence scores.
        """
        pred_scores, pred_labels = torch.max(model_output.softmax(dim=1), 1)
        pred_labels = pred_labels.detach().cpu().numpy()  # [B]
        pred_scores = pred_scores.detach().cpu().numpy()  # [B]

        predictions = list()
        for prediction, confidence, image_input in zip(pred_labels, pred_scores, model_input):
            predictions.append(ClassificationPrediction(confidence=float(confidence), label=int(prediction), image_shape=image_input.shape))
        return predictions

    def _instantiate_image_prediction(self, image: np.ndarray, prediction: ClassificationPrediction) -> ImagePrediction:
        return ImageClassificationPrediction(image=image, prediction=prediction, class_names=self.class_names)

    def _combine_image_prediction_to_images(
        self, images_predictions: Iterable[ImageClassificationPrediction], n_images: Optional[int] = None
    ) -> Union[ImagesClassificationPrediction, ImageClassificationPrediction]:
        if n_images is not None and n_images == 1:
            # Do not show tqdm progress bar if there is only one image
            images_predictions = next(iter(images_predictions))
        else:
            images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Images")]
            images_predictions = ImagesClassificationPrediction(_images_prediction_lst=images_predictions)
        return images_predictions

    def _combine_image_prediction_to_video(
        self, images_predictions: Iterable[ImageClassificationPrediction], fps: float, n_images: Optional[int] = None
    ) -> ImagesClassificationPrediction:
        raise NotImplementedError("This feature is not available for the classification task")

class SegmentationPipeline(Pipeline):
    """Pipeline specifically designed for segmentation tasks.
    The pipeline includes loading images, preprocessing, prediction, and postprocessing.

    :param model:           The segmentation model (instance of SgModule) used for making predictions.
    :param class_names:     List of class names corresponding to the model's output classes.
    :param image_processor: Single image processor or a list of image processors for preprocessing and postprocessing the images.
    :param device:          The device on which the model will be run. If None, will run on current model device. Use "cuda" for GPU support.
    :param fuse_model:      If True, create a copy of the model, and fuse some of its layers to increase performance. This increases memory usage.
    :param fp16:            If True, use mixed precision for inference.
    """

    def __init__(
        self,
        model: SgModule,
        class_names: List[str],
        device: Optional[str] = None,
        image_processor: Optional[Processing] = None,
        fuse_model: bool = True,
        fp16: bool = True,
    ):
        super().__init__(model=model, device=device, image_processor=image_processor, class_names=class_names, fuse_model=fuse_model, fp16=fp16)

    def _decode_model_output(self, model_output: Union[List, Tuple, torch.Tensor], model_input: np.ndarray) -> List[SegmentationPrediction]:
        """Decode the model output into per-pixel class predictions.

        :param model_output:    Direct output of the model, without any post-processing.
        :param model_input:     Model input (i.e. images after preprocessing).
        :return:                Predicted segmentation maps.
        """
        if type(model_output) is tuple:
            model_output = model_output[0]

        if model_output.size(1) == 1:
            class_prediction = torch.sigmoid(model_output).gt(0.5).squeeze(1).long()
        else:
            class_prediction = torch.argmax(model_output, dim=1)
        class_prediction = class_prediction.detach().cpu().numpy()

        predictions = []
        for prediction, image in zip(class_prediction, model_input):
            predictions.append(
                SegmentationPrediction(
                    segmentation_map=prediction,
                    segmentation_map_shape=prediction.shape,
                    image_shape=image.shape[-2:],
                )
            )
        return predictions

    def _instantiate_image_prediction(self, image: np.ndarray, prediction: SegmentationPrediction) -> ImagePrediction:
        return ImageSegmentationPrediction(image=image, prediction=prediction, class_names=self.class_names)

    def _combine_image_prediction_to_images(
        self, images_predictions: Iterable[ImageSegmentationPrediction], n_images: Optional[int] = None
    ) -> Union[ImagesSegmentationPrediction, ImageSegmentationPrediction]:
        if n_images is not None and n_images == 1:
            # Do not show tqdm progress bar if there is only one image
            images_predictions = next(iter(images_predictions))
        else:
            images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Images")]
            images_predictions = ImagesSegmentationPrediction(_images_prediction_lst=images_predictions)
        return images_predictions

    def _combine_image_prediction_to_video(
        self, images_predictions: Iterable[ImageSegmentationPrediction], fps: float, n_images: Optional[int] = None
    ) -> VideoSegmentationPrediction:
        images_predictions = [image_predictions for image_predictions in tqdm(images_predictions, total=n_images, desc="Predicting Video")]
        return VideoSegmentationPrediction(_images_prediction_lst=images_predictions, fps=fps)
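
# Decode sketch (illustrative, standalone): how `_decode_model_output` above maps
# raw segmentation logits to class maps. A single-channel (binary) head is
# thresholded after a sigmoid, while a multi-channel head is argmax'ed over the
# channel dimension:
#
#     logits = torch.randn(2, 1, 4, 4)                 # binary head, [B, 1, H, W]
#     torch.sigmoid(logits).gt(0.5).squeeze(1).long()  # -> [B, H, W] maps in {0, 1}
#     logits = torch.randn(2, 21, 4, 4)                # 21-class head, [B, C, H, W]
#     torch.argmax(logits, dim=1)                      # -> [B, H, W] class indices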