# context_modules.py

from typing import Type, List, Tuple, Union, Dict
from abc import ABC, abstractmethod
import torch.nn as nn
import torch
import torch.nn.functional as F
from super_gradients.modules import ConvBNReLU
from super_gradients.training.utils.module_utils import UpsampleMode
from super_gradients.common.object_names import ContextModules


class AbstractContextModule(nn.Module, ABC):
    """Base class for context modules; subclasses must report their number of output channels."""

    @abstractmethod
    def output_channels(self):
        raise NotImplementedError

class SPPM(AbstractContextModule):
    """
    Simple Pyramid Pooling context Module.
    """

    def __init__(
        self,
        in_channels: int,
        inter_channels: int,
        out_channels: int,
        pool_sizes: List[Union[int, Tuple[int, int]]],
        upsample_mode: Union[UpsampleMode, str] = UpsampleMode.BILINEAR,
        align_corners: bool = False,
    ):
        """
        :param inter_channels: num channels in each pooling branch.
        :param out_channels: number of output channels after the pyramid pooling module.
        :param pool_sizes: spatial output sizes of the pooled feature maps.
        """
        super().__init__()
        self.branches = nn.ModuleList(
            [
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_size),
                    ConvBNReLU(in_channels, inter_channels, kernel_size=1, bias=False),
                )
                for pool_size in pool_sizes
            ]
        )
        self.conv_out = ConvBNReLU(inter_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.out_channels = out_channels
        self.upsample_mode = upsample_mode
        self.align_corners = align_corners
        self.pool_sizes = pool_sizes

    def forward(self, x):
        out = None
        input_shape = x.shape[2:]
        for branch in self.branches:
            y = branch(x)
            y = F.interpolate(y, size=input_shape, mode=self.upsample_mode, align_corners=self.align_corners)
            out = y if out is None else out + y
        out = self.conv_out(out)
        return out

    def output_channels(self):
        return self.out_channels

    def prep_model_for_conversion(self, input_size: Union[tuple, list], stride_ratio: int = 32, **kwargs):
        """
        Replace global average pooling with fixed-kernel average pooling, since dynamic kernel sizes are not supported
        when compiling to ONNX: `Unsupported: ONNX export of operator adaptive_avg_pool2d, input size not accessible.`
        """
        input_size = [x / stride_ratio for x in input_size[-2:]]
        for branch in self.branches:
            global_pool: nn.AdaptiveAvgPool2d = branch[0]
            # If this is not an adaptive (global) average pooling layer, skip it: the branch might already have
            # been converted to a fixed-kernel average pooling module.
            if not isinstance(global_pool, nn.AdaptiveAvgPool2d):
                continue
            out_size = global_pool.output_size
            out_size = out_size if isinstance(out_size, (tuple, list)) else (out_size, out_size)
            kernel_size = [int(i / o) for i, o in zip(input_size, out_size)]
            branch[0] = nn.AvgPool2d(kernel_size=kernel_size, stride=kernel_size)
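
# Example (illustrative sketch, not part of the original module): an SPPM head on a stride-32 feature map.
# The channel counts and pool sizes below are assumed demo values; the upsample mode is passed as the plain
# string "bilinear" accepted by F.interpolate.
#
#   sppm = SPPM(in_channels=512, inter_channels=128, out_channels=128, pool_sizes=[1, 2, 4], upsample_mode="bilinear")
#   feats = torch.rand(2, 512, 16, 32)           # (N, C, H/32, W/32)
#   out = sppm(feats)                            # (2, 128, 16, 32): branches upsampled and summed, then a 3x3 conv
#
#   # Before ONNX export, adaptive pooling can be swapped for fixed-kernel average pooling,
#   # given the (assumed) full network input size and stride ratio:
#   sppm.prep_model_for_conversion(input_size=(512, 1024), stride_ratio=32)
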
class ASPP(AbstractContextModule):
    """
    ASPP bottleneck block. Splits the input into len(dilation_list) + 1 heads: one 1x1 convolution plus one 3x3
    convolution per dilation rate. The heads are concatenated, and each head outputs
    in_channels / (len(dilation_list) + 1) channels, so with in_out_ratio == 1.0 the number of output channels
    equals the number of input channels.
    """

    def __init__(self, in_channels: int, dilation_list: List[int], in_out_ratio: float = 1.0, use_bias: bool = False, **kwargs):
        """
        :param dilation_list: list of dilation rates. The number of dilation branches should be chosen so that the
            input channels divide evenly, see the assertion below.
        :param in_out_ratio: output / input num of channels ratio.
        :param use_bias: legacy parameter to support PascalVOC frontier checkpoints that were trained by mistake with
            extra redundant biases before batchnorm operators. Should be set to `False` for new training processes.
        """
        super().__init__()
        num_dilation_branches = len(dilation_list) + 1
        inter_ratio = num_dilation_branches / in_out_ratio
        assert in_channels % inter_ratio == 0
        inter_channels = int(in_channels / inter_ratio)
        self.dilated_conv_list = nn.ModuleList(
            [
                ConvBNReLU(in_channels, inter_channels, kernel_size=1, dilation=1, bias=use_bias),
                *[ConvBNReLU(in_channels, inter_channels, kernel_size=3, dilation=d, padding=d, bias=use_bias) for d in dilation_list],
            ]
        )
        self.out_channels = inter_channels * num_dilation_branches

    def output_channels(self):
        return self.out_channels

    def forward(self, x):
        x = torch.cat([dilated_conv(x) for dilated_conv in self.dilated_conv_list], dim=1)
        return x
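
# Example (illustrative sketch, not part of the original module): an ASPP block with assumed dilation rates
# [6, 12, 18]. With in_out_ratio=1.0 there are 4 branches, so in_channels must be divisible by 4 and the
# concatenated output keeps the same number of channels as the input.
#
#   aspp = ASPP(in_channels=256, dilation_list=[6, 12, 18])
#   feats = torch.rand(2, 256, 32, 64)
#   out = aspp(feats)                            # (2, 256, 32, 64): four 64-channel heads concatenated
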
CONTEXT_TYPE_DICT: Dict[str, Type[AbstractContextModule]] = {ContextModules.ASPP: ASPP, ContextModules.SPPM: SPPM}
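

if __name__ == "__main__":
    # Minimal smoke test (added here as an illustrative sketch, not part of the original module): build both
    # context modules through CONTEXT_TYPE_DICT and run a dummy forward pass. All shapes and constructor
    # arguments are arbitrary demo values.
    dummy = torch.rand(1, 256, 16, 32)

    sppm = CONTEXT_TYPE_DICT[ContextModules.SPPM](
        in_channels=256, inter_channels=64, out_channels=128, pool_sizes=[1, 2, 4], upsample_mode="bilinear"
    )
    print("SPPM output shape:", tuple(sppm(dummy).shape))  # expected (1, 128, 16, 32)

    aspp = CONTEXT_TYPE_DICT[ContextModules.ASPP](in_channels=256, dilation_list=[6, 12, 18])
    print("ASPP output shape:", tuple(aspp(dummy).shape))  # expected (1, 256, 16, 32)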