[quant][pt2e] Add reference representation rewrite for statically quantized linear (#107994)

Summary: Add a reference representation rewrite for statically quantized linear in the pt2e flow: the dequantize -> aten.linear -> quantize pattern produced by convert is replaced with an integer-only reference implementation, and torch.ops.aten.linear.default is allowed inside out_dtype so the int32 accumulation can be expressed.
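
A minimal end-to-end sketch of how the rewrite is expected to be used is below. The capture API and import paths (capture_pre_autograd_graph, torch.ao.quantization.pt2e.representation, the XNNPACKQuantizer location) are assumptions based on the pt2e flow at the time and are not part of this diff:

```
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.pt2e.representation import reference_representation_rewrite
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(2, 5),)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))

m = capture_pre_autograd_graph(M().eval(), example_inputs)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibrate
m = convert_pt2e(m)
# Rewrite the dq -> aten.linear -> q pattern into the integer-only reference form
m = reference_representation_rewrite(m)
```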

Test Plan:
```
python test/test_quantization.py TestQuantizePT2E.test_representation_linear
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test:quantization_pt2e -- 'test_representation_linear'
```

Reviewed By: kimishpatel

Differential Revision: D48674862

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107994
Approved by: https://github.com/mcr229, https://github.com/guangy10
Jerry Zhang 2023-08-26 15:39:52 +00:00 committed by PyTorch MergeBot
parent 162109f6c2
commit 15d4dedbbf
3 changed files with 97 additions and 1 deletion


@@ -2114,7 +2114,28 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
            M(), example_inputs, is_per_channel=True, verify_convert=True,
        )

    def test_representation_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(2, 5),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={}
        )

    @unittest.skip("some issues with conv2d rewrite, will fix in a separate PR")
    def test_representation_conv2d(self):
        class M(torch.nn.Module):
            def __init__(self):


@@ -22,6 +22,7 @@ from torch._higher_order_ops.utils import autograd_not_implemented
# TODO to figure out a more generic approach
ALLOWABLE_OPS = [
    torch.ops.aten.linear.default,
    torch.ops.aten.mm.default,
    torch.ops.aten.conv2d.default,
    torch.ops.aten.convolution.default,


@@ -18,6 +18,73 @@ __all__ = [
    "reference_representation_rewrite",
]

_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (2, 5), dtype=torch.int8),   # x_i8
    torch.randn(1, dtype=torch.float),                     # x_scale
    torch.zeros(1, dtype=torch.int),                       # x_zero_point
    torch.tensor([-128], dtype=torch.int),                 # x_quant_min
    torch.tensor([127], dtype=torch.int),                  # x_quant_max
    torch.randint(-128, 127, (5, 5), dtype=torch.int8),   # weight_i8
    torch.randn(1, dtype=torch.float),                     # weight_scale
    torch.zeros(1, dtype=torch.int),                       # weight_zero_point
    torch.tensor([-127], dtype=torch.int),                 # weight_quant_min
    torch.tensor([127], dtype=torch.int),                  # weight_quant_max
    torch.randn(1, dtype=torch.float),                     # bias_fp32
    torch.randn(1, dtype=torch.float),                     # out_scale
    torch.zeros(1, dtype=torch.int),                       # out_zero_point
    torch.tensor([-128], dtype=torch.int),                 # out_quant_min
    torch.tensor([127], dtype=torch.int),                  # out_quant_max
)

def _qdq_quantized_linear(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
    return out_i8

def _reference_quantized_linear(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
    # Without using quant_min/max in clamp, the traced graph will not have quant_min/max args,
    # which results in failure to match the pattern. Therefore, we call torch.ops.aten.clamp here.
    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # Always pass bias=None to aten.linear so that the same representation works
    # whether or not bias_scale == x_scale * weight_scale; the bias is added separately below.
    acc_i32 = out_dtype(
        torch.ops.aten.linear.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None)
    # TODO: change to mul.Scalar
    # Note: we are quantizing bias with these scales without signal from user, but it might be OK
    bias_scale = x_scale * weight_scale
    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    acc_i32 = acc_i32 + bias_i32
    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
    acc_i32 = out_dtype(
        torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * weight_scale / out_scale
    ) + out_zero_point
    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
    return out_i8

_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
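
As an aside (not part of the diff), the requantization arithmetic in `_reference_quantized_linear` above can be sanity-checked against the dequantize -> linear -> quantize path with a small scalar sketch; the concrete scale and zero-point values below are made up for illustration:

```
# Scalar sanity check of the integer-only math in _reference_quantized_linear,
# for a 1x1 "linear" (a single multiply-accumulate). Values are arbitrary.
x_i8, x_scale, x_zp = 23, 0.05, 3
w_i8, w_scale, w_zp = -41, 0.02, 0
bias_fp32 = 0.17
out_scale, out_zp = 0.1, -5

# Reference float path: dequantize -> linear -> quantize
x_fp32 = (x_i8 - x_zp) * x_scale
w_fp32 = (w_i8 - w_zp) * w_scale
out_fp32 = x_fp32 * w_fp32 + bias_fp32
out_ref = round(out_fp32 / out_scale) + out_zp

# Integer-only path mirroring the reference implementation above
acc_i32 = (x_i8 - x_zp) * (w_i8 - w_zp)                    # int32 accumulation
acc_i32 += int(bias_fp32 / (x_scale * w_scale))            # bias quantized with x_scale * w_scale
acc_i32 = int(acc_i32 * (x_scale * w_scale / out_scale)) + out_zp
out_int = max(-128, min(127, acc_i32))                     # clamp to int8 range

print(out_ref, out_int)  # the two paths should agree up to rounding behavior
```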

@@ -398,6 +465,13 @@ class _RewriteInfo:
    replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None

_REWRITE_INFO_LIST = [
    _RewriteInfo(
        _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
        _qdq_quantized_linear,
        _reference_quantized_linear,
        _replace_literals_with_new_placeholders,
        _replace_literals_with_new_placeholders,
    ),
    _RewriteInfo(
        _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
        _qdq_quantized_conv2d,