conv_bn_relu_test.py

import unittest

import torch
import torch.nn as nn

from super_gradients.training.utils.module_utils import ConvBNReLU


class TestConvBnRelu(unittest.TestCase):
    def setUp(self) -> None:
        self.sample = torch.randn(2, 32, 64, 64)
        self.test_kernels = [1, 3, 5]
        self.test_strides = [1, 2]
        self.use_activation = [True, False]
        self.use_normalization = [True, False]
        self.biases = [True, False]

    def test_conv_bn_relu(self):
        for use_normalization in self.use_normalization:
            for use_activation in self.use_activation:
                for kernel in self.test_kernels:
                    for stride in self.test_strides:
                        for bias in self.biases:
                            conv_bn_relu = ConvBNReLU(
                                32,
                                32,
                                kernel_size=kernel,
                                stride=stride,
                                padding=kernel // 2,
                                bias=bias,
                                use_activation=use_activation,
                                use_normalization=use_normalization,
                            )
                            conv_bn_relu_seq = nn.Sequential(
                                nn.Conv2d(32, 32, kernel_size=kernel, stride=stride, padding=kernel // 2, bias=bias),
                                nn.BatchNorm2d(32) if use_normalization else nn.Identity(),
                                nn.ReLU() if use_activation else nn.Identity(),
                            )
                            # Copy the conv weights and biases into the reference module so the
                            # outputs are comparable, since conv parameters are randomly initialized.
                            conv_bn_relu.seq[0].weight = conv_bn_relu_seq[0].weight
                            if bias:
                                conv_bn_relu.seq[0].bias = conv_bn_relu_seq[0].bias
                            self.assertTrue(
                                torch.equal(conv_bn_relu(self.sample), conv_bn_relu_seq(self.sample)),
                                msg=f"ConvBnRelu test failed for configuration: activation: "
                                f"{use_activation}, normalization: {use_normalization}, "
                                f"kernel: {kernel}, stride: {stride}",
                            )

    def test_conv_bn_relu_with_default_torch_arguments(self):
        """
        Check that the default-argument behavior of the ConvBNReLU module is aligned with the
        torch module defaults, so that the behavior of ConvBNReLU doesn't change with torch
        package upgrades.
        """
        conv_bn_relu = ConvBNReLU(32, 32, kernel_size=1)
        conv_bn_relu_defaults_torch = nn.Sequential(nn.Conv2d(32, 32, kernel_size=1), nn.BatchNorm2d(32), nn.ReLU())
        # Copy the conv weights and biases into the reference module so the
        # outputs are comparable, since conv parameters are randomly initialized.
        conv_bn_relu.seq[0].weight = conv_bn_relu_defaults_torch[0].weight
        conv_bn_relu.seq[0].bias = conv_bn_relu_defaults_torch[0].bias
        self.assertTrue(
            torch.equal(conv_bn_relu(self.sample), conv_bn_relu_defaults_torch(self.sample)),
            msg="ConvBnRelu test failed for default arguments configuration: ConvBNReLU default "
            "arguments are not aligned with torch defaults.",
        )


if __name__ == "__main__":
    unittest.main()
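
For context, here is a minimal sketch of what the module under test presumably looks like. This is not the actual super_gradients implementation; it is inferred purely from how the test uses the class: the constructor accepts the arguments passed above, and the layers live in an nn.Sequential attribute named seq, since the test patches conv_bn_relu.seq[0].weight directly.

import torch.nn as nn


class ConvBNReLU(nn.Module):
    """Minimal sketch, inferred from the test above: Conv2d followed by an
    optional BatchNorm2d and an optional ReLU, stored in a `seq` attribute.
    Defaults are assumed to mirror nn.Conv2d (stride=1, padding=0, bias=True),
    which is what test_conv_bn_relu_with_default_torch_arguments verifies."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 bias=True, use_normalization=True, use_activation=True):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=bias),
            nn.BatchNorm2d(out_channels) if use_normalization else nn.Identity(),
            nn.ReLU() if use_activation else nn.Identity(),
        )

    def forward(self, x):
        return self.seq(x)

Keeping the layers in a single nn.Sequential is what makes the weight-sharing trick in the tests possible: index 0 is always the conv layer, regardless of whether normalization or activation is enabled, because disabled stages are replaced with nn.Identity rather than removed.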