Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#609 Ci fix

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/infra-000_ci
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
  1. import torch.nn as nn
  2. from super_gradients.modules import ConvBNReLU
  3. class SegmentationHead(nn.Module):
  4. def __init__(self, in_channels: int, mid_channels: int, num_classes: int, dropout: float):
  5. super(SegmentationHead, self).__init__()
  6. self.seg_head = nn.Sequential(
  7. ConvBNReLU(in_channels, mid_channels, kernel_size=3, padding=1, stride=1, bias=False),
  8. nn.Dropout(dropout),
  9. nn.Conv2d(mid_channels, num_classes, kernel_size=1, bias=False),
  10. )
  11. def forward(self, x):
  12. return self.seg_head(x)
  13. def replace_num_classes(self, num_classes: int):
  14. """
  15. This method replace the last Conv Classification layer to output a different number of classes.
  16. Note that the weights of the new layers are random initiated.
  17. """
  18. old_cls_conv = self.seg_head[-1]
  19. self.seg_head[-1] = nn.Conv2d(old_cls_conv.in_channels, num_classes, kernel_size=1, bias=False)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file