[quant] Add quantize and dequantize operators to decomposition table (#93312)
Summary: This PR decomposes the operators in the torch.ops.quantized_decomposed namespace into more primitive aten operators. This frees us from maintaining the semantics of the quantize/dequantize operators ourselves, since they can be expressed more precisely in terms of the underlying aten operators.

Note: this PR only adds them to the decomposition table; we have not enabled this by default yet.

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_q_dq_decomposition

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93312
Approved by: https://github.com/vkuzo, https://github.com/SherlockNoMad
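For readers unfamiliar with the affine quantization scheme involved, the sketch below illustrates the arithmetic these decompositions lower to, using only aten-level tensor ops. This is a minimal standalone illustration, not the registered implementation itself; the helper names are hypothetical.

import torch

def sketch_quantize(x, scale, zero_point, quant_min, quant_max, dtype):
    # quantize: clamp(round(x / scale) + zero_point, quant_min, quant_max), then cast
    return torch.clamp(
        torch.round(x * (1.0 / scale)) + zero_point, quant_min, quant_max
    ).to(dtype)

def sketch_dequantize(q, scale, zero_point):
    # dequantize: (q - zero_point) * scale, computed in float to avoid integer wraparound
    return (q.to(torch.float32) - zero_point) * scale

x = torch.randn(4)
q = sketch_quantize(x, 0.1, 128, 0, 255, torch.uint8)
x_hat = sketch_dequantize(q, 0.1, 128)  # round-trip error is at most scale / 2 within range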
This commit is contained in:
parent
df13247e67
commit
782e4f5c02
@@ -26,6 +26,17 @@ from torch.ao.ns.fx.utils import (
     compute_sqnr,
 )
 import copy
+from torch._decomp import get_decompositions
+from torch.fx.experimental.proxy_tensor import make_fx
+
+quant_decomp = get_decompositions(
+    [
+        torch.ops.quantized_decomposed.quantize_per_tensor,
+        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor,
+        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
+    ]
+)
 
 @skipIfNoQNNPACK
 class TestQuantizePT2E(QuantizationTestCase):
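A note on the quant_decomp table built above: get_decompositions returns a dict mapping each OpOverload to the Python function registered for it via register_decomposition, and that dict is what make_fx consumes. A minimal sketch, assuming the registration module import path below is correct as of this commit:

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # assumed: registers the q/dq ops and decompositions
from torch._decomp import get_decompositions

table = get_decompositions([
    torch.ops.quantized_decomposed.quantize_per_tensor,
    torch.ops.quantized_decomposed.dequantize_per_tensor,
])
for op_overload, fn in table.items():
    # keys are OpOverload objects, values are the decorated Python functions
    print(op_overload, "->", fn.__name__)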
@@ -124,7 +135,81 @@ class TestQuantizePT2E(QuantizationTestCase):
             ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
             ns.call_function(torch.ops.aten.addmm.default),
         ]
-        self.checkGraphModuleNodes(m, expected_node_list=node_list)
+        self.checkGraphModuleNodes(
+            m,
+            expected_node_list=node_list,
+            expected_node_occurrence=node_occurrence
+        )
+
+    @xfailIfPython311
+    def test_q_dq_decomposition(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                x = self.conv(x)
+                return x
+
+        with override_quantized_engine("qnnpack"):
+            m = M().eval()
+            example_inputs = (torch.randn(1, 1, 3, 3),)
+
+            # program capture
+            m, guards = torchdynamo.export(
+                m,
+                *copy.deepcopy(example_inputs),
+                aten_graph=True,
+                tracing_mode="real",
+            )
+
+            qconfig = get_default_qconfig("qnnpack")
+            qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Conv2d, qconfig)
+            backend_config = get_qnnpack_pt2e_backend_config()
+            m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
+            m(*example_inputs)
+            m = convert_pt2e(m)
+            m(*example_inputs)
+            node_occurrence = {
+                # two for input and weight of the conv, one for the output of the conv
+                ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3,
+            }
+            node_list = [
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+                ns.call_function(torch.ops.aten.convolution.default),
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+            ]
+            self.checkGraphModuleNodes(
+                m,
+                expected_node_list=node_list,
+                expected_node_occurrence=node_occurrence
+            )
+            m = make_fx(m, decomposition_table=quant_decomp)(*copy.deepcopy(example_inputs))
+            node_occurrence = {
+                # check both q/dq are decomposed
+                ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 0,
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 0,
+            }
+            node_list = [
+                # ops in quantize
+                ns.call_function(torch.ops.aten.mul.Tensor),
+                ns.call_function(torch.ops.aten.round.default),
+                ns.call_function(torch.ops.aten.add.Tensor),
+                ns.call_function(torch.ops.aten.clamp.default),
+                # ops in dequantize
+                ns.call_function(torch.ops.aten.sub.Tensor),
+                ns.call_function(torch.ops.aten.mul.Tensor),
+                # conv op
+                ns.call_function(torch.ops.aten.convolution.default),
+            ]
+            self.checkGraphModuleNodes(
+                m,
+                expected_node_list=node_list,
+                expected_node_occurrence=node_occurrence
+            )
 
 class TestQuantizePT2EModels(QuantizationTestCase):
     @skip_if_no_torchvision
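The occurrence checks in the test above boil down to counting call_function nodes by target. A minimal sketch of that idea against a generic torch.fx.GraphModule (the helper name is ours, not the test suite's):

def count_targets(gm, target):
    # count call_function nodes whose target is the given op overload
    return sum(
        1 for n in gm.graph.nodes
        if n.op == "call_function" and n.target is target
    )

# After make_fx with quant_decomp, one would expect, e.g.:
# count_targets(m, torch.ops.quantized_decomposed.quantize_per_tensor.default) == 0
# count_targets(m, torch.ops.aten.round.default) >= 1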
@@ -2645,6 +2645,10 @@ import torch._refs
 import torch._refs.nn.functional
 import torch._refs.special
 
+_QUANTIZED_DECOMPOSED_LIB = torch.library.Library(
+    "quantized_decomposed", "IMPL", "Meta"
+)
+
 
 def activate_meta():
 
@@ -2698,6 +2702,8 @@ def activate_meta():
             _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
         elif "mkl::" in op_overload.name():
             _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
+        elif "quantized_decomposed::" in op_overload.name():
+            _QUANTIZED_DECOMPOSED_LIB.impl(op_overload, fn)
         else:
             _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
 
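The effect of routing quantized_decomposed:: ops to a Meta-dispatch library is that the ops can run on meta tensors, propagating output shape and dtype without touching real data. A hedged sketch of the intent; the import below is an assumption about where the op registrations live as of this commit, and exact dispatch details depend on the registrations:

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # assumed registration import

x = torch.empty(1, 1, 3, 3, device="meta")  # metadata only, no storage
q = torch.ops.quantized_decomposed.quantize_per_tensor(x, 0.1, 128, 0, 255, torch.uint8)
print(q.shape, q.dtype)  # expected: torch.Size([1, 1, 3, 3]) torch.uint8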
@@ -2,6 +2,31 @@ import torch
 from torch.library import Library, impl
 from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
 from typing import Tuple
+from torch._decomp import register_decomposition
+
+def _quantize_per_tensor_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    inv_scale = 1.0 / scale
+    return torch.clamp(
+        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
+    ).to(dtype)
+
+def _dequantize_per_tensor_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return (input.to(torch.float32) - zero_point) * scale
+
 
 
 # Note: decomposed means decomposed quantized tensor, using decomposed so that the
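One behavioral detail worth noting about the shared quantize helper above: torch.round rounds half to even (banker's rounding), so values exactly halfway between two quantized levels do not always round up. A quick demonstration:

import torch

# torch.round implements round-half-to-even, which the quantize helper inherits
print(torch.round(torch.tensor([0.5, 1.5, 2.5])))  # tensor([0., 2., 2.])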
@@ -59,8 +84,18 @@ def quantize_per_tensor(
     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
 
-    inv_scale = 1.0 / scale
-    return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
+    return _quantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
+
+@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor)
+def quantize_per_tensor_decomp_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return _quantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
 
 quantized_decomposed_lib.define(
     "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
@@ -82,15 +117,19 @@ def quantize_per_tensor_tensor(
     """
     assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
     assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
-    return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
+    return _quantize_per_tensor_impl(
+        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
 
-@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
-def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
-    assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
-    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
-    return torch.empty_like(input, dtype=dtype)
+@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor.tensor)
+def quantize_per_tensor_tensor_decomp_impl(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return _quantize_per_tensor_impl(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
 
 # Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
 # the signature as metadata for the input Tensor, this might be useful for pattern
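The .tensor overload above takes scale and zero_point as one-element tensors and funnels into the same scalar helper via .item(). A minimal usage sketch, assuming the registration import below:

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # assumed registration import

x = torch.randn(2, 2)
scale = torch.tensor([0.1])        # one-element tensor, as the asserts require
zero_point = torch.tensor([128])
q = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
    x, scale, zero_point, 0, 255, torch.uint8
)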
@@ -138,11 +177,22 @@ def dequantize_per_tensor(
         # TODO: investigate why
         # (input - zero_point).to(torch.float32) * scale
         # failed the test
-        return (input.to(torch.float32) - zero_point) * scale
+        return _dequantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
     else:
         raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
 
 
+@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor)
+def dequantize_per_tensor_decomp_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return _dequantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
+
 quantized_decomposed_lib.define(
     "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
     "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
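A plausible answer to the TODO carried along in the hunk above, offered as our reading rather than a conclusion from the commit: (input - zero_point) evaluated before the float cast runs in the integer dtype, where values below zero_point wrap around (uint8 has no negatives), while casting to float32 first avoids the wraparound. A quick demonstration of the suspected failure mode:

import torch

q = torch.tensor([10], dtype=torch.uint8)
print((q - 128).to(torch.float32))  # tensor([138.])  -- subtraction wrapped in uint8
print(q.to(torch.float32) - 128)    # tensor([-118.]) -- cast first, intended value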
@@ -163,23 +213,26 @@ def dequantize_per_tensor_tensor(
     """
     assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
     assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
-    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
-
-@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
-def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
-    assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
-    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
-    if dtype in [torch.uint8, torch.int8, torch.int32]:
-        return torch.empty_like(input, dtype=torch.float32)
-    else:
-        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
+    return _dequantize_per_tensor_impl(
+        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
 
 quantized_decomposed_lib.define(
     "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
     "ScalarType dtype) -> (Tensor, Tensor)")
 
 
+@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor)
+def dequantize_per_tensor_tensor_decomp_impl(
+    input: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return _dequantize_per_tensor_impl(
+        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
+
 @impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
 def choose_qparams_tensor(
     input: torch.Tensor,
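Putting the pieces together, here is a minimal end-to-end sketch of the new path: trace a function that quantizes and dequantizes, supply the decomposition table, and the resulting graph should contain only aten ops. Assumptions: the registration import below, and the op signatures as shown in the hunks above.

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # assumed registration import
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def q_dq(x):
    q = torch.ops.quantized_decomposed.quantize_per_tensor(x, 0.1, 128, 0, 255, torch.uint8)
    return torch.ops.quantized_decomposed.dequantize_per_tensor(q, 0.1, 128, 0, 255, torch.uint8)

decomp = get_decompositions([
    torch.ops.quantized_decomposed.quantize_per_tensor,
    torch.ops.quantized_decomposed.dequantize_per_tensor,
])
gm = make_fx(q_dq, decomposition_table=decomp)(torch.randn(1, 1, 3, 3))
# Expect aten.mul/round/add/clamp and aten.sub/mul in the printed graph,
# with no quantized_decomposed ops remaining.
print(gm.graph)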