[quant][pt2e] Add reference representation rewrite for statically quantized linear (#107994)
Summary: att

Test Plan:
```
python test/test_quantization.py TestQuantizePT2E.test_representation_linear
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test:quantization_pt2e -- 'test_representation_linear'
```

Reviewed By: kimishpatel

Differential Revision: D48674862

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107994
Approved by: https://github.com/mcr229, https://github.com/guangy10
This commit is contained in:
parent 162109f6c2
commit 15d4dedbbf
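For context, a minimal sketch of the flow this commit targets: quantize a linear model through the PT2E flow, then ask for the reference representation, which applies the rewrite added below and replaces the dq -> linear -> q pattern with explicit integer arithmetic. This is a sketch assuming the PT2E APIs of this era (`capture_pre_autograd_graph`, the `use_reference_representation` flag on `convert_pt2e`); names may differ across releases.

```python
# Sketch: exercise the reference representation rewrite for a statically
# quantized linear (API names assumed from the PyTorch of this era).
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Linear(5, 5).eval()
example_inputs = (torch.randn(2, 5),)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))

m = capture_pre_autograd_graph(model, example_inputs)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibrate observers
# use_reference_representation=True triggers reference_representation_rewrite,
# so the converted graph carries integer ops instead of q/dq pairs around linear
m = convert_pt2e(m, use_reference_representation=True)
```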
@@ -2114,7 +2114,28 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
```python
            M(), example_inputs, is_per_channel=True, verify_convert=True,
        )

    @unittest.skip("some issues with conv2d rewrite, will fix in a separate PR")
    def test_representation_linear(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_global(operator_config)
        example_inputs = (torch.randn(2, 5),)

        self._test_representation(
            M().eval(),
            example_inputs,
            quantizer,
            ref_node_occurrence={},
            non_ref_node_occurrence={}
        )

    def test_representation_conv2d(self):
        class M(torch.nn.Module):
            def __init__(self):
```
@@ -22,6 +22,7 @@ from torch._higher_order_ops.utils import autograd_not_implemented
```python
# TODO: figure out a more generic approach
ALLOWABLE_OPS = [
    torch.ops.aten.linear.default,
    torch.ops.aten.mm.default,
    torch.ops.aten.conv2d.default,
    torch.ops.aten.convolution.default,
```
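This hunk whitelists `aten.linear` for the `out_dtype` higher-order op, which the reference representation below relies on to run a linear over int16 operands while accumulating into int32. A minimal usage sketch (assuming the `out_dtype` API of this era; it only works for `aten.linear` once this hunk is applied):

```python
# Minimal out_dtype usage: the output dtype is specified explicitly
# instead of being inferred from the operands.
import torch
from torch._higher_order_ops.out_dtype import out_dtype

x = torch.randint(-128, 127, (2, 5), dtype=torch.int16)
w = torch.randint(-128, 127, (5, 5), dtype=torch.int16)
# linear with no bias, accumulated in int32 rather than int16
acc_i32 = out_dtype(torch.ops.aten.linear.default, torch.int32, x, w, None)
print(acc_i32.dtype)  # torch.int32
```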
@@ -18,6 +18,73 @@ __all__ = [
```python
    "reference_representation_rewrite",
]


_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (2, 5), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-127], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)

def _qdq_quantized_linear(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8)
    return out_i8

def _reference_quantized_linear(
    x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max,
    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
    bias_fp32,
    out_scale, out_zero_point, out_quant_min, out_quant_max
):
    # Without using quant_min/max in the clamp, the traced graph will not
    # have quant_min/max args, and the pattern will fail to match.
    # Therefore, we call torch.ops.aten.clamp here.
    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # Always set bias to None so that the same representation works whether
    # or not bias_scale == x_scale * weight_scale.
    acc_i32 = out_dtype(
        torch.ops.aten.linear.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None)
    # TODO: change to mul.Scalar
    # Note: we are quantizing bias with these scales without signal from
    # the user, but it might be OK
    bias_scale = x_scale * weight_scale
    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    acc_i32 = acc_i32 + bias_i32
    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
    acc_i32 = out_dtype(
        torch.ops.aten.mul.Tensor, torch.int32,
        acc_i32, x_scale * weight_scale / out_scale) + out_zero_point
    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
    return out_i8
```
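To see why the integer path matches the q/dq pattern: dequantization gives `x_fp32 = x_scale * (x_i8 - x_zp)` and `weight_fp32 = weight_scale * (w_i8 - w_zp)`, so the fp32 linear output equals `x_scale * weight_scale * [(x_i8 - x_zp) @ (w_i8 - w_zp).T + bias_fp32 / (x_scale * weight_scale)]`; quantizing then divides by `out_scale` and adds `out_zero_point`, which is exactly the `acc_i32` arithmetic above. The two functions can be checked numerically with a quick sketch like the following (the scales and zero points here are hypothetical, chosen positive since the random example inputs above are only meant for tracing; an off-by-one difference is tolerated because `out_dtype(mul, ...)` truncates where `quantize_per_tensor` rounds to nearest):

```python
# Sketch: compare the q/dq pattern with the integer reference path.
# Assumes _qdq_quantized_linear, _reference_quantized_linear, and out_dtype
# from this file are in scope; quantization parameters are hypothetical.
import torch

torch.manual_seed(0)
x_i8 = torch.randint(-128, 128, (2, 5), dtype=torch.int8)
w_i8 = torch.randint(-127, 128, (5, 5), dtype=torch.int8)
bias_fp32 = torch.randn(5)
x_scale, w_scale, out_scale = 0.1, 0.05, 0.2  # hypothetical scales
x_zp = w_zp = out_zp = 0

args = (
    x_i8, x_scale, x_zp, -128, 127,
    w_i8, w_scale, w_zp, -127, 127,
    bias_fp32,
    out_scale, out_zp, -128, 127,
)
qdq = _qdq_quantized_linear(*args)
ref = _reference_quantized_linear(*args)
# truncation vs round-to-nearest can differ by at most one quantized step
assert (qdq.to(torch.int32) - ref.to(torch.int32)).abs().max() <= 1
```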
The hunk continues with the conv2d pattern's example inputs (truncated in this view):

```python
_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
```
@@ -398,6 +465,13 @@ class _RewriteInfo:
```python
    replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None

_REWRITE_INFO_LIST = [
    _RewriteInfo(
        _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
        _qdq_quantized_linear,
        _reference_quantized_linear,
        _replace_literals_with_new_placeholders,
        _replace_literals_with_new_placeholders,
    ),
    _RewriteInfo(
        _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
        _qdq_quantized_conv2d,
```
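For orientation, roughly how a `_REWRITE_INFO_LIST` entry is consumed by `reference_representation_rewrite`: each pattern/replacement pair is traced into a GraphModule on the example inputs, optionally post-processed, and handed to the FX subgraph rewriter. The sketch below is inferred from the fields visible above, not the exact upstream implementation; field names other than `replacement_post_trans` are assumptions.

```python
# Sketch of applying one _RewriteInfo entry (field names mostly assumed).
from torch.fx import subgraph_rewriter
from torch.fx.experimental.proxy_tensor import make_fx

def _apply_rewrite(model, info):
    # trace the q/dq pattern and its integer replacement on the example inputs
    pattern = make_fx(info.pattern)(*info.example_inputs)
    replacement = make_fx(info.replacement)(*info.example_inputs)
    # post-transforms such as _replace_literals_with_new_placeholders turn
    # literals baked in during tracing back into placeholders, so the
    # pattern matches graphs with arbitrary quantization parameters
    if info.pattern_post_trans is not None:
        pattern = info.pattern_post_trans(pattern)
    if info.replacement_post_trans is not None:
        replacement = info.replacement_post_trans(replacement)
    subgraph_rewriter.replace_pattern(model, pattern, replacement)
    return model
```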