[quant][pt2e] Add reference representation for dynamic quantized linear (#108073)
Summary: as titled — add a reference (integer-arithmetic) representation for dynamically quantized linear.

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_representation_dynamic_linear
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test:quantization_pt2e -- 'test_representation_dynamic_linear'

Reviewed By: kimishpatel
Differential Revision: D48703076
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108073
Approved by: https://github.com/andrewor14
This commit is contained in:
parent 0cfc5899f9
commit 147b3495e2
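For orientation, here is a minimal sketch of the flow this reference representation plugs into, mirroring the new test added below. The import paths and the use_reference_representation flag reflect the PT2E APIs of this period and are assumptions for illustration, not part of this diff.

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(2, 5),)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True))

m = capture_pre_autograd_graph(M().eval(), example_inputs)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibration run (a no-op for dynamic quantization, but keeps the flow uniform)
# with use_reference_representation=True, convert_pt2e rewrites the q/dq patterns
# into the integer reference form this PR adds for dynamic linear
m = convert_pt2e(m, use_reference_representation=True)
out = m(*example_inputs)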
@@ -552,6 +552,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
         quantizer: Quantizer,
         ref_node_occurrence: Dict[ns, int],
         non_ref_node_occurrence: Dict[ns, int],
+        fixed_output_tol: float = None,
         output_scale_idx: int = 3,
     ) -> torch.nn.Module:
         """ TODO: need to implement output checking based on output_scale once
@@ -581,17 +582,22 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
         self.checkGraphModuleNodes(model_copy, expected_node_occurrence=non_ref_node_occurrence)
         pt2e_quant_output_copy = model_copy(*example_inputs)

-        idx = 0
-        for n in model_copy.graph.nodes:
-            if n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default:
-                idx += 1
-                if idx == output_scale_idx:
-                    output_scale = n.args[1]
-        assert output_scale is not None
+
+        output_tol = None
+        if fixed_output_tol is not None:
+            output_tol = fixed_output_tol
+        else:
+            idx = 0
+            for n in model_copy.graph.nodes:
+                if n.target == torch.ops.quantized_decomposed.quantize_per_tensor.default:
+                    idx += 1
+                    if idx == output_scale_idx:
+                        output_tol = n.args[1]
+            assert output_tol is not None

         # make sure the result is off by one at most in the quantized integer representation
         self.assertTrue(
-            torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) <= (2 * output_scale + 1e-5)
+            torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output)) <= (2 * output_tol + 1e-5)
         )

     @skipIfNoQNNPACK
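A note on the tolerance logic above: for statically quantized outputs the helper reads the output scale off the graph's quantize_per_tensor nodes (n.args[1]), and an off-by-one disagreement in the int8 domain corresponds to one scale unit after dequantization, which the 2 * output_tol + 1e-5 bound leaves room for. The dynamic linear path added in this PR leaves the output in fp32 with no final quantize node to read a scale from, which is why the helper now also accepts a caller-supplied fixed_output_tol. A small, hypothetical illustration of the bound (the values here are made up):

import torch

output_scale = 0.05
ref_i8 = torch.tensor([10, -3], dtype=torch.int8)
alt_i8 = torch.tensor([11, -3], dtype=torch.int8)  # one element off by a single int8 step

# the dequantized difference is one scale unit per off-by-one step,
# inside the 2 * scale + 1e-5 bound used by the test helper
diff = (alt_i8.float() - ref_i8.float()) * output_scale
assert torch.max(torch.abs(diff)) <= 2 * output_scale + 1e-5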
@@ -2148,6 +2154,29 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             non_ref_node_occurrence={}
         )

+    def test_representation_dynamic_linear(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True)
+        quantizer.set_global(operator_config)
+        example_inputs = (torch.randn(2, 5),)
+
+        self._test_representation(
+            M().eval(),
+            example_inputs,
+            quantizer,
+            ref_node_occurrence={},
+            non_ref_node_occurrence={},
+            fixed_output_tol=1e-4,
+        )
+
     def test_representation_conv2d(self):
         class M(torch.nn.Module):
             def __init__(self):
@@ -85,6 +85,72 @@ def _reference_quantized_linear(
     return out_i8


+_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
+    torch.randn((2, 5), dtype=torch.float),
+    -128,
+    127,
+    torch.finfo(torch.float32).eps,
+    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+    torch.randn(1, dtype=torch.float),
+    torch.zeros(1, dtype=torch.int),
+    torch.tensor([-127], dtype=torch.int),
+    torch.tensor([127], dtype=torch.int),
+    torch.randn(1, dtype=torch.float),
+)
+
+
+def _qdq_dynamic_quantized_linear(
+    x_fp32, x_quant_min, x_quant_max, x_eps,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+):
+    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8)
+    x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+def _reference_dynamic_quantized_linear(
+    x_fp32, x_quant_min, x_quant_max, x_eps,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+):
+    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8)
+    # decomposed representation for quantize_per_tensor
+    # TODO: use out_dtype(mul, ...) here when the op is ready
+    x_fp32 = x_fp32 / x_scale  # fp32
+    # round modes might be different here
+    # pytorch is rounding to even, which is also common for most of the backends
+    x_fp32 = torch.round(x_fp32)  # fp32
+    x_i32 = x_fp32.to(dtype=torch.int32)  # int32
+    x_i32 = x_i32 + x_zero_point  # int32
+    # clamp works for fp32, int32 and int8 dtypes
+    x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max)  # int32
+    x_i8 = x_i32.to(dtype=torch.int8)
+
+    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)
+
+    x_i16 = x_i8.to(torch.int16)
+    weight_i16 = weight_i8.to(torch.int16)
+    # always set bias to None so that the same representation can work for the case
+    # no matter if bias_scale == x_scale * weight_scale or not
+    acc_i32 = out_dtype(
+        torch.ops.aten.linear.default,
+        torch.int32,
+        x_i16 - x_zero_point,
+        weight_i16 - weight_zero_point,
+        None)
+    bias_scale = x_scale * weight_scale
+    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
+    acc_i32 = acc_i32 + bias_i32
+    out_fp32 = acc_i32 * (x_scale * weight_scale)
+    return out_fp32
+
+
 _QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
     torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
     torch.randn(1, dtype=torch.float),
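The two functions above are meant to be numerically interchangeable: _qdq_dynamic_quantized_linear is the pattern (the q/dq graph convert_pt2e produces today) and _reference_dynamic_quantized_linear is the replacement expressed in integer arithmetic via out_dtype. A minimal sanity-check sketch, assuming both helpers (and the out_dtype import from their module) are in scope — they are private to the representation rewrite module, so in practice the pattern-matching rewrite and the test above do this comparison for you. Plain ints and a positive scale are used instead of the pattern's random example-input tensors so the numeric comparison is meaningful:

import torch

x_fp32 = torch.randn(2, 5)
weight_i8 = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
weight_scale = torch.tensor([0.02])
weight_zero_point = torch.tensor([0], dtype=torch.int)
bias_fp32 = torch.randn(5)

args = (
    x_fp32, -128, 127, torch.finfo(torch.float32).eps,
    weight_i8, weight_scale, weight_zero_point, -127, 127,
    bias_fp32,
)

qdq_out = _qdq_dynamic_quantized_linear(*args)
ref_out = _reference_dynamic_quantized_linear(*args)
# both return fp32; they should agree up to quantization rounding error
print(torch.max(torch.abs(qdq_out - ref_out)))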
@@ -465,6 +531,27 @@ class _RewriteInfo:
     replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None

 _REWRITE_INFO_LIST = [
+    _RewriteInfo(
+        _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
+        _qdq_dynamic_quantized_linear,
+        _reference_dynamic_quantized_linear,
+        partial(
+            _replace_literals_with_existing_placeholders,
+            literal_to_ph_idx={
+                -128: 1,
+                127: 2,
+                torch.finfo(torch.float32).eps: 3
+            }
+        ),
+        partial(
+            _replace_literals_with_existing_placeholders,
+            literal_to_ph_idx={
+                -128: 1,
+                127: 2,
+                torch.finfo(torch.float32).eps: 3
+            }
+        ),
+    ),
     _RewriteInfo(
         _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
         _qdq_quantized_linear,
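A closing note on the registration above: literal_to_ph_idx maps literals that get burned into the traced pattern/replacement graphs back onto placeholder positions, and the indices used here line up with the argument order of _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS (placeholder 0 is x_fp32; 1, 2 and 3 are x_quant_min, x_quant_max and x_eps). A hypothetical consistency check, assuming that tuple is in scope:

import torch

ex = _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS
assert ex[1] == -128                            # x_quant_min -> placeholder 1
assert ex[2] == 127                             # x_quant_max -> placeholder 2
assert ex[3] == torch.finfo(torch.float32).eps  # x_eps       -> placeholder 3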