| import torch.nn as nn | |
| import torch | |
def quantize(tensor, scale, zero_point, is_asym=False):
    """Fake-quantize `tensor` onto an 8-bit integer grid.

    Args:
        tensor: real-valued input tensor.
        scale: quantization step size (broadcastable to `tensor`).
        zero_point: integer offset of the grid (broadcastable to `tensor`).
        is_asym: if True use the unsigned asymmetric range [0, 255];
            otherwise the signed symmetric range [-128, 127].

    Returns:
        Tensor of (float-typed) integer grid values, clamped to the
        representable range.
    """
    if is_asym:
        # Unsigned 8-bit (uint8) range for asymmetric quantization.
        clamp_min, clamp_max = torch.tensor(0.), torch.tensor(255.)
    else:
        # Signed 8-bit (int8) range for symmetric quantization.
        clamp_min, clamp_max = torch.tensor(-128.), torch.tensor(127.)
    # Bug fix: the zero point must be applied BEFORE clamping
    # (q = clamp(round(x/s) + zp, qmin, qmax)); the original added it after,
    # which could push values outside the representable integer range
    # (e.g. clamp -> 255, then + zp -> 265).
    quant_tensor = torch.clamp(torch.round(tensor / scale) + zero_point,
                               clamp_min, clamp_max)
    return quant_tensor
def dequantize(tensor, scale, zero_point):
    """Map integer-grid values back to real values: (q - zero_point) * scale."""
    shifted = tensor - zero_point
    return shifted * scale
class QuantLinear(nn.Module):
    """Linear layer with fake (simulated) quantization of weights and inputs.

    Scales, zero points, and the SmoothQuant activation multiplier are read
    from `quant_param` and stored as non-trainable buffers. The forward pass
    quantizes then immediately dequantizes both operands before calling the
    ordinary float linear kernel.
    """

    def __init__(self, quant_param):
        super().__init__()
        # SmoothQuant per-channel multiplier, reshaped to its broadcast shape.
        self.register_buffer(
            'mul_factor',
            torch.tensor(quant_param['smoothquant_mul']).view(
                quant_param['smoothquant_mul_shape']))
        self.linear = nn.Linear(128, 128)
        # Weight/input scales and zero points; each key has a matching
        # '<key>_shape' entry describing its broadcast shape.
        for name in ('weight_scale', 'weight_zp', 'input_scale', 'input_zp'):
            value = torch.tensor(quant_param[name]).view(
                quant_param[name + '_shape'])
            self.register_buffer(name, value)

    def forward(self, x):
        scaled_x = x * self.mul_factor
        # Weights use asymmetric (unsigned) quantization; inputs symmetric.
        w_q = quantize(self.linear.weight, self.weight_scale, self.weight_zp,
                       is_asym=True)
        x_q = quantize(scaled_x, self.input_scale, self.input_zp,
                       is_asym=False)
        w_dq = dequantize(w_q, self.weight_scale, self.weight_zp)
        x_dq = dequantize(x_q, self.input_scale, self.input_zp)
        return torch.nn.functional.linear(x_dq, w_dq, self.linear.bias)
class QuantConv2d(nn.Module):
    """Conv2d layer with fake (simulated) quantization of weights and inputs.

    Scales, zero points, and the SmoothQuant activation multiplier are read
    from `quant_param` and stored as non-trainable buffers. The forward pass
    quantizes then immediately dequantizes both operands before calling the
    ordinary float conv2d kernel.
    """

    def __init__(self, quant_param):
        super().__init__()
        # SmoothQuant per-channel multiplier, reshaped to its broadcast shape.
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.conv2d = nn.Conv2d(128, 128, 3)
        # Weight/input scales and zero points; each key has a matching
        # '<key>_shape' entry describing its broadcast shape.
        weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
        weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
        input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
        input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
        self.register_buffer('weight_scale', weight_scale)
        self.register_buffer('weight_zp', weight_zp)
        self.register_buffer('input_scale', input_scale)
        self.register_buffer('input_zp', input_zp)

    def forward(self, x):
        scaled_x = x * self.mul_factor
        # Bug fix: was `self.linear.weight` (copy-paste from QuantLinear);
        # this class has no `linear` attribute, so every forward pass raised
        # AttributeError. The conv's own weight is the correct operand.
        quant_weight = quantize(self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        out = torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)
        return out