1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
|
- # Ultralytics ๐ AGPL-3.0 License - https://ultralytics.com/license
- """
- Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector.
- RT-DETR offers real-time performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT.
- It features an efficient hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
- References:
- https://arxiv.org/pdf/2304.08069.pdf
- """
- from ultralytics.engine.model import Model
- from ultralytics.nn.tasks import RTDETRDetectionModel
- from .predict import RTDETRPredictor
- from .train import RTDETRTrainer
- from .val import RTDETRValidator
- class RTDETR(Model):
- """
- Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
- This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware query
- selection, and adaptable inference speed.
- Attributes:
- model (str): Path to the pre-trained model.
- Examples:
- >>> from ultralytics import RTDETR
- >>> model = RTDETR("rtdetr-l.pt")
- >>> results = model("image.jpg")
- """
- def __init__(self, model: str = "rtdetr-l.pt") -> None:
- """
- Initialize the RT-DETR model with the given pre-trained model file.
- Args:
- model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
- Raises:
- NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
- """
- super().__init__(model=model, task="detect")
- @property
- def task_map(self) -> dict:
- """
- Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
- Returns:
- (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
- """
- return {
- "detect": {
- "predictor": RTDETRPredictor,
- "validator": RTDETRValidator,
- "trainer": RTDETRTrainer,
- "model": RTDETRDetectionModel,
- }
- }
|