Thank you! We'll be in touch ASAP.
Something went wrong, please try again or contact us directly at contact@dagshub.com
Deci-AI:master
deci-ai:bugfix/infra-000_ci
import torch from torch import nn from torch.nn.modules.loss import _Loss class FocalLoss(_Loss): """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)""" def __init__(self, loss_fcn: nn.BCEWithLogitsLoss, gamma=1.5, alpha=0.25): super(FocalLoss, self).__init__() self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss() self.gamma = gamma self.alpha = alpha self.reduction = loss_fcn.reduction self.loss_fcn.reduction = 'none' # required to apply FocalLoss to each element def forward(self, pred, true): loss = self.loss_fcn(pred, true) pred_prob = torch.sigmoid(pred) # prob from logits p_t = true * pred_prob + (1 - true) * (1 - pred_prob) alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha) modulating_factor = (1.0 - p_t) ** self.gamma loss *= alpha_factor * modulating_factor if self.reduction == 'mean': return loss.mean() elif self.reduction == 'sum': return loss.sum() else: # 'none' return loss
Press p or to see the previous file or, n or to see the next file