|
@@ -1,4 +1,4 @@
|
|
-from typing import Type, Union, Mapping, Any
|
|
|
|
|
|
+from typing import Type, Union, Mapping, Any, Optional
|
|
|
|
|
|
import torch
|
|
import torch
|
|
from torch import nn
|
|
from torch import nn
|
|
@@ -64,7 +64,7 @@ class QARepVGGBlock(nn.Module):
|
|
"""
|
|
"""
|
|
:param in_channels: Number of input channels
|
|
:param in_channels: Number of input channels
|
|
:param out_channels: Number of output channels
|
|
:param out_channels: Number of output channels
|
|
- :param activation_type: Type of the nonlinearity
|
|
|
|
|
|
+ :param activation_type: Type of the nonlinearity (nn.ReLU by default)
|
|
:param se_type: Type of the se block (Use nn.Identity to disable SE)
|
|
:param se_type: Type of the se block (Use nn.Identity to disable SE)
|
|
:param stride: Output stride
|
|
:param stride: Output stride
|
|
:param dilation: Dilation factor for 3x3 conv
|
|
:param dilation: Dilation factor for 3x3 conv
|
|
@@ -133,10 +133,16 @@ class QARepVGGBlock(nn.Module):
|
|
self.identity = Residual()
|
|
self.identity = Residual()
|
|
|
|
|
|
input_dim = self.in_channels // self.groups
|
|
input_dim = self.in_channels // self.groups
|
|
- self.id_tensor = torch.zeros((self.in_channels, input_dim, 3, 3))
|
|
|
|
|
|
+ id_tensor = torch.zeros((self.in_channels, input_dim, 3, 3))
|
|
for i in range(self.in_channels):
|
|
for i in range(self.in_channels):
|
|
- self.id_tensor[i, i % input_dim, 1, 1] = 1.0
|
|
|
|
- self.id_tensor = self.id_tensor.to(dtype=self.branch_1x1.weight.dtype, device=self.branch_1x1.weight.device)
|
|
|
|
|
|
+ id_tensor[i, i % input_dim, 1, 1] = 1.0
|
|
|
|
+
|
|
|
|
+ self.id_tensor: Optional[torch.Tensor]
|
|
|
|
+ self.register_buffer(
|
|
|
|
+ name="id_tensor",
|
|
|
|
+ tensor=id_tensor.to(dtype=self.branch_1x1.weight.dtype, device=self.branch_1x1.weight.device),
|
|
|
|
+ persistent=False, # so it's not saved in state_dict
|
|
|
|
+ )
|
|
else:
|
|
else:
|
|
self.identity = None
|
|
self.identity = None
|
|
|
|
|
|
@@ -234,7 +240,10 @@ class QARepVGGBlock(nn.Module):
|
|
A = gamma / std
|
|
A = gamma / std
|
|
A_ = A.expand_as(kernel.transpose(0, -1)).transpose(0, -1)
|
|
A_ = A.expand_as(kernel.transpose(0, -1)).transpose(0, -1)
|
|
|
|
|
|
- return kernel * A_, bias * A + b
|
|
|
|
|
|
+ fused_kernel = kernel * A_
|
|
|
|
+ fused_bias = bias * A + b
|
|
|
|
+
|
|
|
|
+ return fused_kernel, fused_bias
|
|
|
|
|
|
def full_fusion(self):
|
|
def full_fusion(self):
|
|
"""Fuse everything into Conv-Act-SE, non-trainable, parameters detached
|
|
"""Fuse everything into Conv-Act-SE, non-trainable, parameters detached
|
|
@@ -299,7 +308,10 @@ class QARepVGGBlock(nn.Module):
|
|
self.fully_fused = False
|
|
self.fully_fused = False
|
|
|
|
|
|
def fuse_block_residual_branches(self):
|
|
def fuse_block_residual_branches(self):
|
|
- self.full_fusion()
|
|
|
|
|
|
+ # inference frameworks will take care of resulting conv-bn-act-se
|
|
|
|
+ # no need to fuse post_bn prematurely if it is there
|
|
|
|
+ # call self.full_fusion() if you need it
|
|
|
|
+ self.partial_fusion()
|
|
|
|
|
|
- def from_repvgg(self, repvgg_block: RepVGGBlock):
|
|
|
|
|
|
+ def from_repvgg(self, src: RepVGGBlock):
|
|
raise NotImplementedError
|
|
raise NotImplementedError
|