diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 2760a99f390..321b5a95f74 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -2114,7 +2114,28 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             M(), example_inputs, is_per_channel=True, verify_convert=True,
         )
 
-    @unittest.skip("some issues with conv2d rewrite, will fix in a separate PR")
+    def test_representation_linear(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(is_per_channel=False)
+        quantizer.set_global(operator_config)
+        example_inputs = (torch.randn(2, 5),)
+
+        self._test_representation(
+            M().eval(),
+            example_inputs,
+            quantizer,
+            ref_node_occurrence={},
+            non_ref_node_occurrence={}
+        )
+
     def test_representation_conv2d(self):
         class M(torch.nn.Module):
             def __init__(self):
diff --git a/torch/_higher_order_ops/out_dtype.py b/torch/_higher_order_ops/out_dtype.py
index 5183fd0d2b7..131f7c04e07 100644
--- a/torch/_higher_order_ops/out_dtype.py
+++ b/torch/_higher_order_ops/out_dtype.py
@@ -22,6 +22,7 @@ from torch._higher_order_ops.utils import autograd_not_implemented
 
 # TODO to figure out a more generic approach
 ALLOWABLE_OPS = [
+    torch.ops.aten.linear.default,
     torch.ops.aten.mm.default,
     torch.ops.aten.conv2d.default,
     torch.ops.aten.convolution.default,
diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py
index 5abc894fe58..c7c27925e54 100644
--- a/torch/ao/quantization/pt2e/representation/rewrite.py
+++ b/torch/ao/quantization/pt2e/representation/rewrite.py
@@ -18,6 +18,73 @@ __all__ = [
     "reference_representation_rewrite",
 ]
 
+
+_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
+    torch.randint(-128, 127, (2, 5), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-127], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+    torch.randn(1, dtype=torch.float),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-128], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+)
+
+def _qdq_quantized_linear(
+    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+    out_scale, out_zero_point, out_quant_min, out_quant_max
+):
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
+    return out_i8
+
+def _reference_quantized_linear(
+    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+    out_scale, out_zero_point, out_quant_min, out_quant_max
+):
+    # Without using quant_min/quant_max in a clamp, the traced graph will not have
+    # quant_min/quant_max as arguments, which results in a failure to match the pattern.
+    # Therefore, we call torch.ops.aten.clamp here.
+    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
+    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
+
+    x_i16 = x_i8.to(torch.int16)
+    weight_i16 = weight_i8.to(torch.int16)
+    # Always set bias to None so that the same representation works
+    # whether or not bias_scale == x_scale * weight_scale
+    acc_i32 = out_dtype(
+        torch.ops.aten.linear.default,
+        torch.int32,
+        x_i16 - x_zero_point,
+        weight_i16 - weight_zero_point,
+        None)
+    # TODO: change to mul.Scalar
+    # Note: we are quantizing bias with these scales without a signal from the user, but it might be OK
+    bias_scale = x_scale * weight_scale
+    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
+    acc_i32 = acc_i32 + bias_i32
+    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
+    acc_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * weight_scale / out_scale) + out_zero_point
+    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
+    return out_i8
+
+
 _QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
     torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
     torch.randn(1, dtype=torch.float),
@@ -398,6 +465,13 @@ class _RewriteInfo:
     replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
 
 _REWRITE_INFO_LIST = [
+    _RewriteInfo(
+        _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
+        _qdq_quantized_linear,
+        _reference_quantized_linear,
+        _replace_literals_with_new_placeholders,
+        _replace_literals_with_new_placeholders,
+    ),
     _RewriteInfo(
         _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
         _qdq_quantized_conv2d,
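As a quick sanity check on the arithmetic that _reference_quantized_linear encodes, the following is a minimal standalone sketch comparing the dequantize -> linear -> quantize path with the integer-only path. It is illustrative only: the quantization parameters are made-up literals rather than the PR's example inputs, and a broadcast multiply-and-sum stands in for the out_dtype(torch.ops.aten.linear.default, torch.int32, ...) call so the sketch does not depend on integer matmul support.

import torch

# Illustrative quantization parameters (not taken from the PR).
x_scale, x_zp = 0.05, 3
w_scale, w_zp = 0.02, 0
out_scale, out_zp = 0.10, -5

x_i8 = torch.randint(-128, 127, (2, 5), dtype=torch.int8)
w_i8 = torch.randint(-127, 127, (5, 5), dtype=torch.int8)
bias_fp32 = torch.randn(5)

# Float path (what _qdq_quantized_linear expresses): dequantize, linear, requantize.
x_fp32 = (x_i8.to(torch.float) - x_zp) * x_scale
w_fp32 = (w_i8.to(torch.float) - w_zp) * w_scale
out_fp32 = torch.nn.functional.linear(x_fp32, w_fp32, bias_fp32)
ref_i8 = torch.clamp(torch.round(out_fp32 / out_scale) + out_zp, -128, 127).to(torch.int8)

# Integer path (the shape of _reference_quantized_linear): subtract zero points,
# accumulate in int32, fold the bias in at the accumulator scale x_scale * w_scale,
# then rescale to the output scale.
x_i32 = (x_i8.to(torch.int16) - x_zp).to(torch.int32)          # (2, 5)
w_i32 = (w_i8.to(torch.int16) - w_zp).to(torch.int32)          # (5, 5)
acc_i32 = (x_i32.unsqueeze(1) * w_i32.unsqueeze(0)).sum(-1)    # (2, 5), same as x @ w.T
acc_i32 = acc_i32 + (bias_fp32 / (x_scale * w_scale)).to(torch.int32)
out_i32 = (acc_i32 * (x_scale * w_scale / out_scale)).to(torch.int32) + out_zp
int_i8 = torch.clamp(out_i32, -128, 127).to(torch.int8)

# The integer path truncates where the float path rounds, so expect the two
# results to agree typically within +/-1 on a few entries.
print((ref_i8.to(torch.int32) - int_i8.to(torch.int32)).abs().max())

The widening to int16 before subtracting the zero point mirrors the rewrite itself: the difference between an int8 value and its zero point can fall outside the int8 range, so the representation upcasts first and lets out_dtype keep the accumulation in int32.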