Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#609 Ci fix

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/infra-000_ci
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
  1. import torch
  2. from torch import nn, Tensor
  3. import torch.nn.functional as F
  4. class SEBlock(nn.Module):
  5. """
  6. Spatial Squeeze and Channel Excitation Block (cSE).
  7. Figure 1, Variant a from https://arxiv.org/abs/1808.08127v1
  8. """
  9. def __init__(self, in_channels: int, internal_neurons: int):
  10. super(SEBlock, self).__init__()
  11. self.down = nn.Conv2d(in_channels=in_channels, out_channels=internal_neurons, kernel_size=1, stride=1, bias=True)
  12. self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=in_channels, kernel_size=1, stride=1, bias=True)
  13. self.input_channels = in_channels
  14. def forward(self, inputs: Tensor) -> Tensor:
  15. x = F.avg_pool2d(inputs, kernel_size=inputs.size(3))
  16. x = self.down(x)
  17. x = F.relu(x)
  18. x = self.up(x)
  19. x = torch.sigmoid(x)
  20. x = x.view(-1, self.input_channels, 1, 1)
  21. return inputs * x
  22. class EffectiveSEBlock(nn.Module):
  23. """Effective Squeeze-Excitation
  24. From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
  25. """
  26. def __init__(self, in_channels: int):
  27. super().__init__()
  28. self.project = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0)
  29. self.act = nn.Hardsigmoid(inplace=True)
  30. def forward(self, x: Tensor) -> Tensor:
  31. x_se = x.mean((2, 3), keepdim=True)
  32. x_se = self.project(x_se)
  33. return x * self.act(x_se)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file