1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
|
- # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
- from copy import copy
- from pathlib import Path
- from typing import Dict, Optional, Union
- from ultralytics.models import yolo
- from ultralytics.nn.tasks import SegmentationModel
- from ultralytics.utils import DEFAULT_CFG, RANK
- from ultralytics.utils.plotting import plot_images, plot_results
class SegmentationTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on a segmentation model.

    This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
    functionality including model initialization, validation, and visualization.

    Attributes:
        loss_names (Tuple[str]): Names of the loss components used during training (assigned in `get_validator`).

    Examples:
        >>> from ultralytics.models.yolo.segment import SegmentationTrainer
        >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
        >>> trainer = SegmentationTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict] = None, _callbacks=None):
        """
        Initialize a SegmentationTrainer object.

        This initializes a trainer for segmentation tasks, extending the detection trainer with segmentation-specific
        functionality. It forces the task to 'segment' and delegates the rest of the setup to the parent trainer.

        Args:
            cfg: Base training configuration passed through to the parent trainer (defaults to DEFAULT_CFG).
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration. The caller's
                dictionary is not modified; a shallow copy receives the forced ``task="segment"`` entry.
            _callbacks (list, optional): List of callback functions to be executed during training.

        Examples:
            >>> from ultralytics.models.yolo.segment import SegmentationTrainer
            >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
            >>> trainer = SegmentationTrainer(overrides=args)
            >>> trainer.train()
        """
        # Shallow-copy so forcing the task below does not leak a "task" key back into the caller's dict.
        overrides = {} if overrides is None else dict(overrides)
        overrides["task"] = "segment"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(
        self, cfg: Optional[Union[Dict, str]] = None, weights: Optional[Union[str, Path]] = None, verbose: bool = True
    ):
        """
        Initialize and return a SegmentationModel with specified configuration and weights.

        Args:
            cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
            weights (str | Path, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information during initialization.

        Returns:
            (SegmentationModel): Initialized segmentation model with loaded weights if specified.

        Examples:
            >>> trainer = SegmentationTrainer()
            >>> model = trainer.get_model(cfg="yolo11n-seg.yaml")
            >>> model = trainer.get_model(weights="yolo11n-seg.pt", verbose=False)
        """
        # Class count and input channels come from the dataset config; only print model info when RANK == -1.
        model = SegmentationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def get_validator(self):
        """Return an instance of SegmentationValidator for validation of YOLO model."""
        # Segmentation adds a mask loss term ("seg_loss") to the detection loss components.
        self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
        return yolo.segment.SegmentationValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch: Dict, ni: int):
        """
        Plot training sample images with labels, bounding boxes, and masks.

        This method creates a visualization of training batch images with their corresponding labels, bounding boxes,
        and segmentation masks, saving the result to a file for inspection and debugging.

        Args:
            batch (dict): Dictionary containing batch data with the following keys:
                'img': Images tensor
                'batch_idx': Batch indices for each box
                'cls': Class labels tensor (its last dimension is squeezed before plotting)
                'bboxes': Bounding box coordinates tensor
                'masks': Segmentation masks tensor
                'im_file': List of image file paths
            ni (int): Current training iteration number, used for naming the output file.

        Examples:
            >>> trainer = SegmentationTrainer()
            >>> batch = {
            ...     "img": torch.rand(16, 3, 640, 640),
            ...     "batch_idx": torch.zeros(16),
            ...     "cls": torch.randint(0, 80, (16, 1)),
            ...     "bboxes": torch.rand(16, 4),
            ...     "masks": torch.rand(16, 640, 640),
            ...     "im_file": [f"image{i}.jpg" for i in range(16)],
            ... }
            >>> trainer.plot_training_samples(batch, ni=5)
        """
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            masks=batch["masks"],
            paths=batch["im_file"],
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plot training/validation metrics from the results CSV, including segmentation columns."""
        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
|