@@ -11,212 +11,13 @@ Based on https://github.com/DingXiaoH/RepVGG
 from typing import Union
 
 import torch.nn as nn
-import numpy as np
-import torch
-import torch.nn.parallel
-import torch.optim
-import torch.utils.data
-import torch.utils.data.distributed
+
+from super_gradients.modules import RepVGGBlock, SEBlock
 from super_gradients.training.models.sg_module import SgModule
-import torch.nn.functional as F
 from super_gradients.training.utils.module_utils import fuse_repvgg_blocks_residual_branches
 from super_gradients.training.utils.utils import get_param
 
 
-class SEBlock(nn.Module):
-    def __init__(self, input_channels, internal_neurons):
-        super(SEBlock, self).__init__()
-        self.down = nn.Conv2d(
-            in_channels=input_channels, out_channels=internal_neurons, kernel_size=1, stride=1, bias=True
-        )
-        self.up = nn.Conv2d(
-            in_channels=internal_neurons, out_channels=input_channels, kernel_size=1, stride=1, bias=True
-        )
-        self.input_channels = input_channels
-
-    def forward(self, inputs):
-        x = F.avg_pool2d(inputs, kernel_size=inputs.size(3))
-        x = self.down(x)
-        x = F.relu(x)
-        x = self.up(x)
-        x = torch.sigmoid(x)
-        x = x.view(-1, self.input_channels, 1, 1)
-        return inputs * x
-
-
-def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1, dilation=1):
-    result = nn.Sequential()
-    result.add_module(
-        "conv",
-        nn.Conv2d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            groups=groups,
-            bias=False,
-            dilation=dilation,
-        ),
-    )
-    result.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
-    return result
-
-
-class RepVGGBlock(nn.Module):
-    """
-    Repvgg block consists of three branches
-    3x3: a branch of a 3x3 convolution + batchnorm + relu
-    1x1: a branch of a 1x1 convolution + batchnorm + relu
-    no_conv_branch: a branch with only batchnorm which will only be used if input channel == output channel
-    (usually in all but the first block of each stage)
-    """
-
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size,
-        stride=1,
-        padding=0,
-        dilation=1,
-        groups=1,
-        build_residual_branches=True,
-        use_relu=True,
-        use_se=False,
-    ):
-
-        super(RepVGGBlock, self).__init__()
-
-        self.groups = groups
-        self.in_channels = in_channels
-
-        assert kernel_size == 3
-        assert padding == dilation
-
-        self.nonlinearity = nn.ReLU() if use_relu else nn.Identity()
-        self.se = nn.Identity() if not use_se else SEBlock(out_channels, internal_neurons=out_channels // 16)
-
-        self.no_conv_branch = (
-            nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
-        )
-        self.branch_3x3 = conv_bn(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            dilation=dilation,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            groups=groups,
-        )
-        self.branch_1x1 = conv_bn(
-            in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0, groups=groups
-        )
-
-        if not build_residual_branches:
-            self.fuse_block_residual_branches()
-        else:
-            self.build_residual_branches = True
-
-    def forward(self, inputs):
-        if not self.build_residual_branches:
-            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
-
-        if self.no_conv_branch is None:
-            id_out = 0
-        else:
-            id_out = self.no_conv_branch(inputs)
-
-        return self.nonlinearity(self.se(self.branch_3x3(inputs) + self.branch_1x1(inputs) + id_out))
-
-    def _get_equivalent_kernel_bias(self):
-        """
-        Fuses the 3x3, 1x1 and identity branches into a single 3x3 conv layer
-        """
-        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.branch_3x3)
-        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.branch_1x1)
-        kernelid, biasid = self._fuse_bn_tensor(self.no_conv_branch)
-        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
-
-    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
-        """
-        padding the 1x1 convolution weights with zeros to be able to fuse the 3x3 conv layer with the 1x1
-        :param kernel1x1: weights of the 1x1 convolution
-        :type kernel1x1:
-        :return: padded 1x1 weights
-        :rtype:
-        """
-        if kernel1x1 is None:
-            return 0
-        else:
-            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
-
-    def _fuse_bn_tensor(self, branch):
-        """
-        Fusing of the batchnorm into the conv layer.
-        If the branch is the identity branch (no conv) the kernel will simply be eye.
-        :param branch:
-        :type branch:
-        :return:
-        :rtype:
-        """
-        if branch is None:
-            return 0, 0
-        if isinstance(branch, nn.Sequential):
-            kernel = branch.conv.weight
-            running_mean = branch.bn.running_mean
-            running_var = branch.bn.running_var
-            gamma = branch.bn.weight
-            beta = branch.bn.bias
-            eps = branch.bn.eps
-        else:
-            assert isinstance(branch, nn.BatchNorm2d)
-            if not hasattr(self, "id_tensor"):
-                input_dim = self.in_channels // self.groups
-                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
-                for i in range(self.in_channels):
-                    kernel_value[i, i % input_dim, 1, 1] = 1
-                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
-            kernel = self.id_tensor
-            running_mean = branch.running_mean
-            running_var = branch.running_var
-            gamma = branch.weight
-            beta = branch.bias
-            eps = branch.eps
-        std = (running_var + eps).sqrt()
-        t = (gamma / std).reshape(-1, 1, 1, 1)
-        return kernel * t, beta - running_mean * gamma / std
-
-    def fuse_block_residual_branches(self):
-        """
-        converts a repvgg block from training model (with branches) to deployment mode (vgg like model)
-        :return:
-        :rtype:
-        """
-        if hasattr(self, "build_residual_branches") and not self.build_residual_branches:
-            return
-        kernel, bias = self._get_equivalent_kernel_bias()
-        self.rbr_reparam = nn.Conv2d(
-            in_channels=self.branch_3x3.conv.in_channels,
-            out_channels=self.branch_3x3.conv.out_channels,
-            kernel_size=self.branch_3x3.conv.kernel_size,
-            stride=self.branch_3x3.conv.stride,
-            padding=self.branch_3x3.conv.padding,
-            dilation=self.branch_3x3.conv.dilation,
-            groups=self.branch_3x3.conv.groups,
-            bias=True,
-        )
-        self.rbr_reparam.weight.data = kernel
-        self.rbr_reparam.bias.data = bias
-        for para in self.parameters():
-            para.detach_()
-        self.__delattr__("branch_3x3")
-        self.__delattr__("branch_1x1")
-        if hasattr(self, "no_conv_branch"):
-            self.__delattr__("no_conv_branch")
-        self.build_residual_branches = False
-
-
 class RepVGG(SgModule):
     def __init__(
         self,
@@ -253,11 +54,12 @@ class RepVGG(SgModule):
         self.stem = RepVGGBlock(
             in_channels=in_channels,
             out_channels=self.in_planes,
-            kernel_size=3,
             stride=2,
-            padding=1,
             build_residual_branches=build_residual_branches,
-            use_se=self.use_se,
+            activation_type=nn.ReLU,
+            activation_kwargs=dict(inplace=True),
+            se_type=SEBlock if self.use_se else nn.Identity,
+            se_kwargs=dict(in_channels=self.in_planes, internal_neurons=self.in_planes // 16) if self.use_se else None,
         )
         self.cur_layer_idx = 1
         self.stage1 = self._make_stage(int(64 * width_multiplier[0]), struct[0], stride=2)
@@ -282,12 +84,13 @@ class RepVGG(SgModule):
                 RepVGGBlock(
                     in_channels=self.in_planes,
                     out_channels=planes,
-                    kernel_size=3,
                     stride=stride,
-                    padding=1,
                     groups=1,
                     build_residual_branches=self.build_residual_branches,
-                    use_se=self.use_se,
+                    activation_type=nn.ReLU,
+                    activation_kwargs=dict(inplace=True),
+                    se_type=SEBlock if self.use_se else nn.Identity,
+                    se_kwargs=dict(in_channels=self.in_planes, internal_neurons=self.in_planes // 16) if self.use_se else None,
                 )
             )
         self.in_planes = planes
@@ -312,10 +115,9 @@ class RepVGG(SgModule):
 
     def train(self, mode: bool = True):
 
-        assert not mode or self.build_residual_branches, (
-            "Trying to train a model without residual branches, "
-            "set arch_params.build_residual_branches to True and retrain the model"
-        )
+        assert (
+            not mode or self.build_residual_branches
+        ), "Trying to train a model without residual branches, set arch_params.build_residual_branches to True and retrain the model"
         super(RepVGG, self).train(mode=mode)
 
     def replace_head(self, new_num_classes=None, new_head=None):
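
A minimal usage sketch of the relocated block, assuming the RepVGGBlock and SEBlock now exported from super_gradients.modules keep the keyword interface shown at the call sites in this diff (activation_type/activation_kwargs, se_type/se_kwargs, build_residual_branches) and that the fusing method keeps the fuse_block_residual_branches name from the removed implementation; channel counts and tensor shapes are illustrative only:

import torch
import torch.nn as nn

from super_gradients.modules import RepVGGBlock, SEBlock

channels = 64
block = RepVGGBlock(
    in_channels=channels,
    out_channels=channels,
    stride=1,
    build_residual_branches=True,  # training form: parallel 3x3, 1x1 and identity (BN-only) branches
    activation_type=nn.ReLU,
    activation_kwargs=dict(inplace=True),
    se_type=SEBlock,  # pass nn.Identity (with se_kwargs=None) to disable squeeze-and-excitation
    se_kwargs=dict(in_channels=channels, internal_neurons=channels // 16),
)

x = torch.randn(1, channels, 32, 32)
block.eval()  # use BN running stats so the unfused and fused outputs are comparable
y_branches = block(x)

# Deployment form: fold each conv+BN pair into a conv (W' = W * gamma / std, b' = beta - mean * gamma / std),
# pad the 1x1 kernel to 3x3, and sum the three branches into a single 3x3 conv, as in the removed code above.
block.fuse_block_residual_branches()
y_fused = block(x)
assert torch.allclose(y_branches, y_fused, atol=1e-5)  # equal up to numerical precision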