Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#20413 YOLOE: Fix visual prompt training

Merged
Ghost merged 1 commits into Ultralytics:main from ultralytics:yoloe-vp-fix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
  1. # Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license
  2. """
  3. Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector.
  4. RT-DETR offers real-time performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT.
  5. It features an efficient hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
  6. References:
  7. https://arxiv.org/pdf/2304.08069.pdf
  8. """
  9. from ultralytics.engine.model import Model
  10. from ultralytics.nn.tasks import RTDETRDetectionModel
  11. from .predict import RTDETRPredictor
  12. from .train import RTDETRTrainer
  13. from .val import RTDETRValidator
  14. class RTDETR(Model):
  15. """
  16. Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
  17. This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware query
  18. selection, and adaptable inference speed.
  19. Attributes:
  20. model (str): Path to the pre-trained model.
  21. Examples:
  22. >>> from ultralytics import RTDETR
  23. >>> model = RTDETR("rtdetr-l.pt")
  24. >>> results = model("image.jpg")
  25. """
  26. def __init__(self, model: str = "rtdetr-l.pt") -> None:
  27. """
  28. Initialize the RT-DETR model with the given pre-trained model file.
  29. Args:
  30. model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
  31. Raises:
  32. NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
  33. """
  34. super().__init__(model=model, task="detect")
  35. @property
  36. def task_map(self) -> dict:
  37. """
  38. Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
  39. Returns:
  40. (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
  41. """
  42. return {
  43. "detect": {
  44. "predictor": RTDETRPredictor,
  45. "validator": RTDETRValidator,
  46. "trainer": RTDETRTrainer,
  47. "model": RTDETRDetectionModel,
  48. }
  49. }
Discard
Tip!

Press p or to see the previous file or, n or to see the next file