1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
|
import torch
from torch import nn

from pytorch_quantization import nn as quant_nn

from super_gradients.training.dataloaders import cifar10_train
from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
from super_gradients.training.utils.quantization.core import SGQuantMixin
from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer
def e2e_example():
    """End-to-end quantization demo.

    Builds a toy model, selectively swaps one block for a quantized
    counterpart, calibrates it post-training (PTQ) on CIFAR-10, runs a
    shape sanity check, and exports the quantized module to ONNX.

    NOTE(review): assumes a CUDA device is available — both the model and
    the sanity-check input are placed on "cuda"; confirm before running on
    a CPU-only host.
    """

    class MyBlock(nn.Module):
        """Float block: flatten followed by a fully-connected layer."""

        def __init__(self, in_feats, out_feats) -> None:
            super().__init__()
            self.in_feats = in_feats
            self.out_feats = out_feats
            self.flatten = nn.Flatten()
            self.linear = nn.Linear(in_feats, out_feats)

        def forward(self, x):
            return self.linear(self.flatten(x))

    class MyQuantizedBlock(SGQuantMixin):
        """Quantized counterpart of ``MyBlock``, swapped in by SelectiveQuantizer."""

        def __init__(self, in_feats, out_feats) -> None:
            super().__init__()
            self.flatten = nn.Flatten()
            self.linear = quant_nn.QuantLinear(in_feats, out_feats)

        def forward(self, x):
            return self.linear(self.flatten(x))

    class MyModel(nn.Module):
        """Toy classifier over ``res`` x ``res`` RGB inputs."""

        def __init__(self, res, n_classes) -> None:
            super().__init__()
            self.my_block = MyBlock(3 * (res**2), n_classes)

        def forward(self, x):
            return self.my_block(x)

    res = 32
    n_clss = 10
    module = MyModel(res, n_clss)

    # QUANTIZE: register the float->quantized mapping, then rewrite the module in place.
    q_util = SelectiveQuantizer()
    q_util.register_quantization_mapping(layer_names={"my_block"}, quantized_target_class=MyQuantizedBlock)
    q_util.quantize_module(module)

    # CALIBRATE (PTQ): collect activation statistics from real training data.
    train_loader = cifar10_train()
    calib = QuantizationCalibrator()
    calib.calibrate_model(module, method=q_util.default_quant_modules_calib_method_inputs, calib_data_loader=train_loader)

    module.cuda()

    # SANITY: one forward pass and a shape check. Fixed: the original used
    # torch.testing.assert_close, which compares tensor *values*, not shapes;
    # a torch.Size compares equal to a plain tuple, so direct equality is the
    # correct and clearer check.
    x = torch.rand(1, 3, res, res, device="cuda")
    with torch.no_grad():
        y = module(x)
    if y.size() != (1, n_clss):
        raise AssertionError(f"unexpected output shape: {tuple(y.size())}, expected {(1, n_clss)}")

    print(module)

    # EXPORT TO ONNX
    export_quantized_module_to_onnx(module, "my_quantized_model.onnx", input_shape=(1, 3, res, res))
- if __name__ == "__main__":
- e2e_example()
|