#367 fix: Request correct hydra-core

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:hotfix/ALG-000_hydra-req
import unittest

import torch
import torch.nn as nn

from super_gradients.training.utils.module_utils import ConvBNReLU


class TestConvBnRelu(unittest.TestCase):
    def setUp(self) -> None:
        self.sample = torch.randn(2, 32, 64, 64)
        self.test_kernels = [1, 3, 5]
        self.test_strides = [1, 2]
        self.use_activation = [True, False]
        self.use_normalization = [True, False]
        self.biases = [True, False]

    def test_conv_bn_relu(self):
        for use_normalization in self.use_normalization:
            for use_activation in self.use_activation:
                for kernel in self.test_kernels:
                    for stride in self.test_strides:
                        for bias in self.biases:
                            conv_bn_relu = ConvBNReLU(32, 32, kernel_size=kernel, stride=stride, padding=kernel // 2,
                                                      bias=bias, use_activation=use_activation,
                                                      use_normalization=use_normalization)
                            conv_bn_relu_seq = nn.Sequential(
                                nn.Conv2d(32, 32, kernel_size=kernel, stride=stride, padding=kernel // 2,
                                          bias=bias),
                                nn.BatchNorm2d(32) if use_normalization else nn.Identity(),
                                nn.ReLU() if use_activation else nn.Identity()
                            )
                            # apply the same conv weights and biases to compare outputs,
                            # because conv weights and biases are randomly initialized.
                            conv_bn_relu.seq[0].weight = conv_bn_relu_seq[0].weight
                            if bias:
                                conv_bn_relu.seq[0].bias = conv_bn_relu_seq[0].bias

                            self.assertTrue(torch.equal(conv_bn_relu(self.sample), conv_bn_relu_seq(self.sample)),
                                            msg=f"ConvBnRelu test failed for configuration: activation: "
                                                f"{use_activation}, normalization: {use_normalization}, "
                                                f"kernel: {kernel}, stride: {stride}")

    def test_conv_bn_relu_with_default_torch_arguments(self):
        """
        Check that the default-argument behavior of the ConvBNReLU module is aligned with torch module defaults,
        so that the behavior of ConvBNReLU doesn't change with torch package upgrades.
        """
        conv_bn_relu = ConvBNReLU(32, 32, kernel_size=1)
        conv_bn_relu_defaults_torch = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        # apply the same conv weights and biases to compare outputs,
        # because conv weights and biases are randomly initialized.
        conv_bn_relu.seq[0].weight = conv_bn_relu_defaults_torch[0].weight
        conv_bn_relu.seq[0].bias = conv_bn_relu_defaults_torch[0].bias

        self.assertTrue(torch.equal(conv_bn_relu(self.sample), conv_bn_relu_defaults_torch(self.sample)),
                        msg="ConvBnRelu test failed for default arguments configuration: ConvBNReLU default "
                            "arguments are not aligned with torch defaults.")


if __name__ == '__main__':
    unittest.main()
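
For context, a minimal usage sketch of the ConvBNReLU block this test exercises, assuming the ConvBNReLU(in_channels, out_channels, ...) signature implied by the constructor calls above:

import torch

from super_gradients.training.utils.module_utils import ConvBNReLU

# conv -> batch norm -> ReLU, matching the equivalent nn.Sequential built in the test
block = ConvBNReLU(32, 32, kernel_size=3, stride=1, padding=1)
x = torch.randn(2, 32, 64, 64)  # same sample shape the test uses
y = block(x)
print(y.shape)  # torch.Size([2, 32, 64, 64]); spatial size is preserved with stride=1, padding=kernel_size // 2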