Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#632 Fixed id_tensor registry, so reparametrization works when .cuda() is called

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:fix/qarepvgg-register-parameters
1 changed file with 20 additions and 8 deletions
  1. 20
    8
      src/super_gradients/modules/qarepvgg_block.py
@@ -1,4 +1,4 @@
-from typing import Type, Union, Mapping, Any
+from typing import Type, Union, Mapping, Any, Optional
 
 
 import torch
 import torch
 from torch import nn
 from torch import nn
@@ -64,7 +64,7 @@ class QARepVGGBlock(nn.Module):
         """
         """
         :param in_channels: Number of input channels
         :param in_channels: Number of input channels
         :param out_channels: Number of output channels
         :param out_channels: Number of output channels
-        :param activation_type: Type of the nonlinearity
+        :param activation_type: Type of the nonlinearity (nn.ReLU by default)
         :param se_type: Type of the se block (Use nn.Identity to disable SE)
         :param se_type: Type of the se block (Use nn.Identity to disable SE)
         :param stride: Output stride
         :param stride: Output stride
         :param dilation: Dilation factor for 3x3 conv
         :param dilation: Dilation factor for 3x3 conv
@@ -133,10 +133,16 @@ class QARepVGGBlock(nn.Module):
             self.identity = Residual()
             self.identity = Residual()
 
 
             input_dim = self.in_channels // self.groups
             input_dim = self.in_channels // self.groups
-            self.id_tensor = torch.zeros((self.in_channels, input_dim, 3, 3))
+            id_tensor = torch.zeros((self.in_channels, input_dim, 3, 3))
             for i in range(self.in_channels):
             for i in range(self.in_channels):
-                self.id_tensor[i, i % input_dim, 1, 1] = 1.0
-            self.id_tensor = self.id_tensor.to(dtype=self.branch_1x1.weight.dtype, device=self.branch_1x1.weight.device)
+                id_tensor[i, i % input_dim, 1, 1] = 1.0
+
+            self.id_tensor: Optional[torch.Tensor]
+            self.register_buffer(
+                name="id_tensor",
+                tensor=id_tensor.to(dtype=self.branch_1x1.weight.dtype, device=self.branch_1x1.weight.device),
+                persistent=False,  # so it's not saved in state_dict
+            )
         else:
         else:
             self.identity = None
             self.identity = None
 
 
@@ -234,7 +240,10 @@ class QARepVGGBlock(nn.Module):
         A = gamma / std
         A = gamma / std
         A_ = A.expand_as(kernel.transpose(0, -1)).transpose(0, -1)
         A_ = A.expand_as(kernel.transpose(0, -1)).transpose(0, -1)
 
 
-        return kernel * A_, bias * A + b
+        fused_kernel = kernel * A_
+        fused_bias = bias * A + b
+
+        return fused_kernel, fused_bias
 
 
     def full_fusion(self):
     def full_fusion(self):
         """Fuse everything into Conv-Act-SE, non-trainable, parameters detached
         """Fuse everything into Conv-Act-SE, non-trainable, parameters detached
@@ -299,7 +308,10 @@ class QARepVGGBlock(nn.Module):
         self.fully_fused = False
         self.fully_fused = False
 
 
     def fuse_block_residual_branches(self):
     def fuse_block_residual_branches(self):
-        self.full_fusion()
+        # inference frameworks will take care of resulting conv-bn-act-se
+        # no need to fuse post_bn prematurely if it is there
+        # call self.full_fusion() if you need it
+        self.partial_fusion()
 
 
-    def from_repvgg(self, repvgg_block: RepVGGBlock):
+    def from_repvgg(self, src: RepVGGBlock):
         raise NotImplementedError
         raise NotImplementedError
Discard