Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#578 Feature/sg 516 support head replacement for local pretrained weights unknown dataset

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-516_support_head_replacement_for_local_pretrained_weights_unknown_dataset
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
  1. import torch
  2. from super_gradients.training.losses.bce_loss import BCE
  3. from super_gradients.training.losses.dice_loss import BinaryDiceLoss
  4. class BCEDiceLoss(torch.nn.Module):
  5. """
  6. Binary Cross Entropy + Dice Loss
  7. Weighted average of BCE and Dice loss
  8. Attributes:
  9. loss_weights: list of size 2 s.t loss_weights[0], loss_weights[1] are the weights for BCE, Dice
  10. respectively.
  11. """
  12. def __init__(self, loss_weights=[0.5, 0.5], logits=True):
  13. super(BCEDiceLoss, self).__init__()
  14. self.loss_weights = loss_weights
  15. self.bce = BCE()
  16. self.dice = BinaryDiceLoss(apply_sigmoid=logits)
  17. def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  18. """
  19. @param input: Network's raw output shaped (N,1,H,W)
  20. @param target: Ground truth shaped (N,H,W)
  21. """
  22. return self.loss_weights[0] * self.bce(input, target) + self.loss_weights[1] * self.dice(input, target)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file