Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#20413 YOLOE: Fix visual prompt training

Merged
Ghost merged 1 commit into Ultralytics:main from ultralytics:yoloe-vp-fix
@@ -794,15 +794,16 @@ class TVPSegmentLoss(TVPDetectLoss):
 
 
     def __init__(self, model):
     def __init__(self, model):
         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
         """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
+        super().__init__(model)
         self.vp_criterion = v8SegmentationLoss(model)
         self.vp_criterion = v8SegmentationLoss(model)
 
 
     def __call__(self, preds, batch):
     def __call__(self, preds, batch):
         """Calculate the loss for text-visual prompt segmentation."""
         """Calculate the loss for text-visual prompt segmentation."""
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
-        assert self.tp_criterion.reg_max == self.vp_criterion.reg_max
+        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
 
 
-        if self.tp_criterion.reg_max * 4 + self.tp_criterion.nc == feats[0].shape[1]:
-            loss = torch.zeros(4, device=self.tp_criterion.device, requires_grad=True)
+        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
+            loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
             return loss, loss.detach()
             return loss, loss.detach()
 
 
         vp_feats = self._get_vp_features(feats)
         vp_feats = self._get_vp_features(feats)
Discard
Tip!

Press p to see the previous file, or n to see the next file