Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#20413 YOLOE: Fix visual prompt training

Merged
Ghost merged 1 commits into Ultralytics:main from ultralytics:yoloe-vp-fix
@@ -150,6 +150,7 @@ wandb/
 weights/
 weights/
 *.weights
 *.weights
 *.pt
 *.pt
+*.ts
 *.pb
 *.pb
 *.onnx
 *.onnx
 *.engine
 *.engine
Discard
@@ -462,7 +462,7 @@ Model validation on a dataset is streamlined as follows:
 
 
         ```python
         ```python
         from ultralytics import YOLOE
         from ultralytics import YOLOE
-        from ultralytics.models.yolo.yoloe import YOLOEVPTrainer
+        from ultralytics.models.yolo.yoloe import YOLOESegVPTrainer
 
 
         data = dict(
         data = dict(
             train=dict(
             train=dict(
@@ -503,7 +503,7 @@ Model validation on a dataset is streamlined as follows:
             weight_decay=0.025,
             weight_decay=0.025,
             momentum=0.9,
             momentum=0.9,
             workers=4,
             workers=4,
-            trainer=YOLOEVPTrainer,
+            trainer=YOLOESegVPTrainer,  # use YOLOEVPTrainer if converted to detection model
             device="0,1,2,3,4,5,6,7",
             device="0,1,2,3,4,5,6,7",
             freeze=freeze,
             freeze=freeze,
         )
         )
Discard
@@ -443,7 +443,7 @@ class GroundingDataset(YOLODataset):
         """
         """
         assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
         assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
         self.json_file = json_file
         self.json_file = json_file
-        super().__init__(*args, task=task, data={}, **kwargs)
+        super().__init__(*args, task=task, data={"channels": 3}, **kwargs)
 
 
     def get_img_files(self, img_path):
     def get_img_files(self, img_path):
         """
         """
Discard
@@ -291,8 +291,9 @@ class YOLOETrainerFromScratch(YOLOETrainer):
         # NOTE: to make training work properly, set `nc` and `names`
         # NOTE: to make training work properly, set `nc` and `names`
         final_data["nc"] = data["val"][0]["nc"]
         final_data["nc"] = data["val"][0]["nc"]
         final_data["names"] = data["val"][0]["names"]
         final_data["names"] = data["val"][0]["names"]
-        # NOTE: add path with lvis path
+        # NOTE: add path with lvis path and image channels
         final_data["path"] = data["val"][0]["path"]
         final_data["path"] = data["val"][0]["path"]
+        final_data["channels"] = data["val"][0]["channels"]
         self.data = final_data
         self.data = final_data
         if self.args.single_cls:  # consistent with base trainer
         if self.args.single_cls:  # consistent with base trainer
             LOGGER.info("Overriding class names with single class.")
             LOGGER.info("Overriding class names with single class.")
Discard
@@ -794,15 +794,16 @@ class TVPSegmentLoss(TVPDetectLoss):
 
 
     def __init__(self, model):
     def __init__(self, model):
         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
+        super().__init__(model)
         self.vp_criterion = v8SegmentationLoss(model)
         self.vp_criterion = v8SegmentationLoss(model)
 
 
     def __call__(self, preds, batch):
     def __call__(self, preds, batch):
         """Calculate the loss for text-visual prompt segmentation."""
         """Calculate the loss for text-visual prompt segmentation."""
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
-        assert self.tp_criterion.reg_max == self.vp_criterion.reg_max
+        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
 
 
-        if self.tp_criterion.reg_max * 4 + self.tp_criterion.nc == feats[0].shape[1]:
-            loss = torch.zeros(4, device=self.tp_criterion.device, requires_grad=True)
+        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+            loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()
             return loss, loss.detach()
 
 
         vp_feats = self._get_vp_features(feats)
         vp_feats = self._get_vp_features(feats)
Discard