[quant] Add quantize and dequantize operators to decomposition table (#93312)

Summary:
This PR decomposes the operators in the torch.ops.quantized_decomposed namespace into more
primitive aten operators. This frees us from maintaining the semantics of the quantize/dequantize
operators, which can be expressed more precisely in terms of the underlying aten operators.

Note: this PR only adds them to the decomposition table; we have not enabled this by default yet.
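
For context, a minimal sketch of how the decompositions are exercised (not part of this diff; it mirrors the new test, assumes the quantized_decomposed ops are already registered by importing the quantization workflow modules, and uses illustrative scale/zero_point/quant-range values):

```python
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Collect the decompositions added in this PR for the q/dq ops.
quant_decomp = get_decompositions(
    [
        torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor,
    ]
)

def fn(x):
    # scale=0.1, zero_point=128, range [0, 255], dtype uint8 are illustrative
    # values, not something prescribed by this PR.
    q = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, 0.1, 128, 0, 255, torch.uint8
    )
    return torch.ops.quantized_decomposed.dequantize_per_tensor(
        q, 0.1, 128, 0, 255, torch.uint8
    )

# make_fx retraces fn; with the decomposition table, the q/dq calls are
# replaced by aten.mul/round/add/clamp and aten.sub/mul respectively.
gm = make_fx(fn, decomposition_table=quant_decomp)(torch.randn(1, 3, 8, 8))
print(gm.graph)
```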

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_q_dq_decomposition

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93312
Approved by: https://github.com/vkuzo, https://github.com/SherlockNoMad
Jerry Zhang 2023-02-09 12:16:58 -08:00 committed by PyTorch MergeBot
parent df13247e67
commit 782e4f5c02
3 changed files with 168 additions and 24 deletions


@@ -26,6 +26,17 @@ from torch.ao.ns.fx.utils import (
    compute_sqnr,
)
import copy
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

quant_decomp = get_decompositions(
    [
        torch.ops.quantized_decomposed.quantize_per_tensor,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    ]
)
@skipIfNoQNNPACK
class TestQuantizePT2E(QuantizationTestCase):
@@ -124,7 +135,81 @@ class TestQuantizePT2E(QuantizationTestCase):
                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
                ns.call_function(torch.ops.aten.addmm.default),
            ]
            self.checkGraphModuleNodes(m, expected_node_list=node_list)
            self.checkGraphModuleNodes(
                m,
                expected_node_list=node_list,
                expected_node_occurrence=node_occurrence
            )

    @xfailIfPython311
    def test_q_dq_decomposition(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv(x)
                return x

        with override_quantized_engine("qnnpack"):
            m = M().eval()
            example_inputs = (torch.randn(1, 1, 3, 3),)
            # program capture
            m, guards = torchdynamo.export(
                m,
                *copy.deepcopy(example_inputs),
                aten_graph=True,
                tracing_mode="real",
            )
            qconfig = get_default_qconfig("qnnpack")
            qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Conv2d, qconfig)
            backend_config = get_qnnpack_pt2e_backend_config()
            m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
            m(*example_inputs)
            m = convert_pt2e(m)
            m(*example_inputs)
            node_occurrence = {
                # two for input and weight of the conv, one for output for the conv
                ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3,
            }
            node_list = [
                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
                ns.call_function(torch.ops.aten.convolution.default),
                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
            ]
            self.checkGraphModuleNodes(
                m,
                expected_node_list=node_list,
                expected_node_occurrence=node_occurrence
            )

            m = make_fx(m, decomposition_table=quant_decomp)(*copy.deepcopy(example_inputs))
            node_occurrence = {
                # check both q/dq are decomposed
                ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 0,
                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 0,
            }
            node_list = [
                # ops in quantize
                ns.call_function(torch.ops.aten.mul.Tensor),
                ns.call_function(torch.ops.aten.round.default),
                ns.call_function(torch.ops.aten.add.Tensor),
                ns.call_function(torch.ops.aten.clamp.default),
                # ops in dequantize
                ns.call_function(torch.ops.aten.sub.Tensor),
                ns.call_function(torch.ops.aten.mul.Tensor),
                # conv op
                ns.call_function(torch.ops.aten.convolution.default),
            ]
            self.checkGraphModuleNodes(
                m,
                expected_node_list=node_list,
                expected_node_occurrence=node_occurrence
            )
class TestQuantizePT2EModels(QuantizationTestCase):
    @skip_if_no_torchvision


@@ -2645,6 +2645,10 @@ import torch._refs
import torch._refs.nn.functional
import torch._refs.special
_QUANTIZED_DECOMPOSED_LIB = torch.library.Library(
    "quantized_decomposed", "IMPL", "Meta"
)

def activate_meta():
@@ -2698,6 +2702,8 @@ def activate_meta():
            _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
        elif "mkl::" in op_overload.name():
            _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
        elif "quantized_decomposed::" in op_overload.name():
            _QUANTIZED_DECOMPOSED_LIB.impl(op_overload, fn)
        else:
            _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)


@@ -2,6 +2,31 @@ import torch
from torch.library import Library, impl
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from typing import Tuple
from torch._decomp import register_decomposition
def _quantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    inv_scale = 1.0 / scale
    return torch.clamp(
        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
    ).to(dtype)


def _dequantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return (input.to(torch.float32) - zero_point) * scale

# Note: decomposed means decomposed quantized tensor, using decomposed so that the
@@ -59,8 +84,18 @@ def quantize_per_tensor(
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    inv_scale = 1.0 / scale
    return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
    return _quantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)

@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor)
def quantize_per_tensor_decomp_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return _quantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
quantized_decomposed_lib.define(
    "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
@@ -82,15 +117,19 @@ def quantize_per_tensor_tensor(
"""
assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
return _quantize_per_tensor_impl(
input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) # type: ignore[arg-type]
@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
return torch.empty_like(input, dtype=dtype)
@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor.tensor)
def quantize_per_tensor_tensor_decomp_impl(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
return _quantize_per_tensor_impl(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) # type: ignore[arg-type]
# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
# the signature as metadata for the input Tensor, this might be useful for pattern
@@ -138,11 +177,22 @@ def dequantize_per_tensor(
        # TODO: investigate why
        # (input - zero_point).to(torch.float32) * scale
        # failed the test
        return (input.to(torch.float32) - zero_point) * scale
        return _dequantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
    else:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")

@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor)
def dequantize_per_tensor_decomp_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return _dequantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
quantized_decomposed_lib.define(
    "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
@@ -163,23 +213,26 @@ def dequantize_per_tensor_tensor(
"""
assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
if dtype in [torch.uint8, torch.int8, torch.int32]:
return torch.empty_like(input, dtype=torch.float32)
else:
raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
return _dequantize_per_tensor_impl(
input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) # type: ignore[arg-type]
quantized_decomposed_lib.define(
    "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
    "ScalarType dtype) -> (Tensor, Tensor)")
@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor)
def dequantize_per_tensor_tensor_decomp_impl(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return _dequantize_per_tensor_impl(
        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
def choose_qparams_tensor(
    input: torch.Tensor,