|
@@ -794,15 +794,16 @@ class TVPSegmentLoss(TVPDetectLoss):
|
|
|
|
|
|
def __init__(self, model):
|
|
def __init__(self, model):
|
|
"""Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
|
|
"""Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
|
|
|
|
+ super().__init__(model)
|
|
self.vp_criterion = v8SegmentationLoss(model)
|
|
self.vp_criterion = v8SegmentationLoss(model)
|
|
|
|
|
|
def __call__(self, preds, batch):
|
|
def __call__(self, preds, batch):
|
|
"""Calculate the loss for text-visual prompt segmentation."""
|
|
"""Calculate the loss for text-visual prompt segmentation."""
|
|
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
|
|
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
|
|
- assert self.tp_criterion.reg_max == self.vp_criterion.reg_max
|
|
|
|
|
|
+ assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
|
|
|
|
|
|
- if self.tp_criterion.reg_max * 4 + self.tp_criterion.nc == feats[0].shape[1]:
|
|
|
|
- loss = torch.zeros(4, device=self.tp_criterion.device, requires_grad=True)
|
|
|
|
|
|
+ if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
|
|
|
|
+ loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
|
|
return loss, loss.detach()
|
|
return loss, loss.detach()
|
|
|
|
|
|
vp_feats = self._get_vp_features(feats)
|
|
vp_feats = self._get_vp_features(feats)
|