Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

train.py 5.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. from copy import copy
  3. from pathlib import Path
  4. from typing import Dict, Optional, Union
  5. from ultralytics.models import yolo
  6. from ultralytics.nn.tasks import SegmentationModel
  7. from ultralytics.utils import DEFAULT_CFG, RANK
  8. from ultralytics.utils.plotting import plot_images, plot_results
  9. class SegmentationTrainer(yolo.detect.DetectionTrainer):
  10. """
  11. A class extending the DetectionTrainer class for training based on a segmentation model.
  12. This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
  13. functionality including model initialization, validation, and visualization.
  14. Attributes:
  15. loss_names (Tuple[str]): Names of the loss components used during training.
  16. Examples:
  17. >>> from ultralytics.models.yolo.segment import SegmentationTrainer
  18. >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
  19. >>> trainer = SegmentationTrainer(overrides=args)
  20. >>> trainer.train()
  21. """
  22. def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict] = None, _callbacks=None):
  23. """
  24. Initialize a SegmentationTrainer object.
  25. This initializes a trainer for segmentation tasks, extending the detection trainer with segmentation-specific
  26. functionality. It sets the task to 'segment' and prepares the trainer for training segmentation models.
  27. Args:
  28. cfg (dict): Configuration dictionary with default training settings.
  29. overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
  30. _callbacks (list, optional): List of callback functions to be executed during training.
  31. Examples:
  32. >>> from ultralytics.models.yolo.segment import SegmentationTrainer
  33. >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
  34. >>> trainer = SegmentationTrainer(overrides=args)
  35. >>> trainer.train()
  36. """
  37. if overrides is None:
  38. overrides = {}
  39. overrides["task"] = "segment"
  40. super().__init__(cfg, overrides, _callbacks)
  41. def get_model(
  42. self, cfg: Optional[Union[Dict, str]] = None, weights: Optional[Union[str, Path]] = None, verbose: bool = True
  43. ):
  44. """
  45. Initialize and return a SegmentationModel with specified configuration and weights.
  46. Args:
  47. cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
  48. weights (str | Path, optional): Path to pretrained weights file.
  49. verbose (bool): Whether to display model information during initialization.
  50. Returns:
  51. (SegmentationModel): Initialized segmentation model with loaded weights if specified.
  52. Examples:
  53. >>> trainer = SegmentationTrainer()
  54. >>> model = trainer.get_model(cfg="yolo11n-seg.yaml")
  55. >>> model = trainer.get_model(weights="yolo11n-seg.pt", verbose=False)
  56. """
  57. model = SegmentationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
  58. if weights:
  59. model.load(weights)
  60. return model
  61. def get_validator(self):
  62. """Return an instance of SegmentationValidator for validation of YOLO model."""
  63. self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
  64. return yolo.segment.SegmentationValidator(
  65. self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
  66. )
  67. def plot_training_samples(self, batch: Dict, ni: int):
  68. """
  69. Plot training sample images with labels, bounding boxes, and masks.
  70. This method creates a visualization of training batch images with their corresponding labels, bounding boxes,
  71. and segmentation masks, saving the result to a file for inspection and debugging.
  72. Args:
  73. batch (dict): Dictionary containing batch data with the following keys:
  74. 'img': Images tensor
  75. 'batch_idx': Batch indices for each box
  76. 'cls': Class labels tensor (squeezed to remove last dimension)
  77. 'bboxes': Bounding box coordinates tensor
  78. 'masks': Segmentation masks tensor
  79. 'im_file': List of image file paths
  80. ni (int): Current training iteration number, used for naming the output file.
  81. Examples:
  82. >>> trainer = SegmentationTrainer()
  83. >>> batch = {
  84. ... "img": torch.rand(16, 3, 640, 640),
  85. ... "batch_idx": torch.zeros(16),
  86. ... "cls": torch.randint(0, 80, (16, 1)),
  87. ... "bboxes": torch.rand(16, 4),
  88. ... "masks": torch.rand(16, 640, 640),
  89. ... "im_file": ["image1.jpg", "image2.jpg"],
  90. ... }
  91. >>> trainer.plot_training_samples(batch, ni=5)
  92. """
  93. plot_images(
  94. batch["img"],
  95. batch["batch_idx"],
  96. batch["cls"].squeeze(-1),
  97. batch["bboxes"],
  98. masks=batch["masks"],
  99. paths=batch["im_file"],
  100. fname=self.save_dir / f"train_batch{ni}.jpg",
  101. on_plot=self.on_plot,
  102. )
  103. def plot_metrics(self):
  104. """Plot training/validation metrics."""
  105. plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...