Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#20413 YOLOE: Fix visual prompt training

Merged
Ghost merged 1 commits into Ultralytics:main from ultralytics:yoloe-vp-fix
@@ -443,7 +443,7 @@ class GroundingDataset(YOLODataset):
         """
         """
         assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
         assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
         self.json_file = json_file
         self.json_file = json_file
-        super().__init__(*args, task=task, data={}, **kwargs)
+        super().__init__(*args, task=task, data={"channels": 3}, **kwargs)
 
 
     def get_img_files(self, img_path):
     def get_img_files(self, img_path):
         """
         """
Discard
@@ -291,8 +291,9 @@ class YOLOETrainerFromScratch(YOLOETrainer):
         # NOTE: to make training work properly, set `nc` and `names`
         # NOTE: to make training work properly, set `nc` and `names`
         final_data["nc"] = data["val"][0]["nc"]
         final_data["nc"] = data["val"][0]["nc"]
         final_data["names"] = data["val"][0]["names"]
         final_data["names"] = data["val"][0]["names"]
-        # NOTE: add path with lvis path
+        # NOTE: add path with lvis path and image channels
         final_data["path"] = data["val"][0]["path"]
         final_data["path"] = data["val"][0]["path"]
+        final_data["channels"] = data["val"][0]["channels"]
         self.data = final_data
         self.data = final_data
         if self.args.single_cls:  # consistent with base trainer
         if self.args.single_cls:  # consistent with base trainer
             LOGGER.info("Overriding class names with single class.")
             LOGGER.info("Overriding class names with single class.")
Discard
@@ -794,15 +794,16 @@ class TVPSegmentLoss(TVPDetectLoss):
 
 
     def __init__(self, model):
     def __init__(self, model):
         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
+        super().__init__(model)
         self.vp_criterion = v8SegmentationLoss(model)
         self.vp_criterion = v8SegmentationLoss(model)
 
 
     def __call__(self, preds, batch):
     def __call__(self, preds, batch):
         """Calculate the loss for text-visual prompt segmentation."""
         """Calculate the loss for text-visual prompt segmentation."""
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
-        assert self.tp_criterion.reg_max == self.vp_criterion.reg_max
+        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
 
 
-        if self.tp_criterion.reg_max * 4 + self.tp_criterion.nc == feats[0].shape[1]:
-            loss = torch.zeros(4, device=self.tp_criterion.device, requires_grad=True)
+        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+            loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()
             return loss, loss.detach()
 
 
         vp_feats = self._get_vp_features(feats)
         vp_feats = self._get_vp_features(feats)
Discard