Modify signature of dequantize ops for decomposed quantized Tensor (#119173) (#121450)

Summary:
X-link: https://github.com/pytorch/executorch/pull/2308

Note: The initial purpose of this PR is to gather suggestions and feedback on a better alternative, if any.

At present, the dequantize ops for the decomposed quantized Tensor representation, e.g. dequantize_per_tensor(), assume the output dtype to be torch.float and hence do not carry an output dtype in their operator argument lists. This signature becomes unusable when the assumption breaks: if the output dtype differs from torch.float, there is no way to specify it during dequantization.

This change generalizes the signature of dequantize ops like dequantize_per_tensor() for wider use cases where the output dtype can differ from torch.float and needs to be passed during dequantization. The proposal is to add a keyword-only argument, out_dtype, to solve the problem. However, we would also welcome suggestions and feedback on any better alternative.
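For illustration only (not part of this diff), a minimal sketch of the before/after call, assuming the ops are exposed under torch.ops.quantized_decomposed once the module defining quantized_decomposed_lib is imported; the tensor contents and qparams below are made up:

    import torch
    # Registers the quantized_decomposed ops; module path assumed from the file touched here
    import torch.ao.quantization.fx._decomposed  # noqa: F401

    # Hypothetical int8 quantized tensor with illustrative qparams
    q = torch.randint(-128, 128, (4,), dtype=torch.int8)
    scale, zero_point = 0.05, 0

    # Existing behaviour: output dtype is implicitly torch.float32
    fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        q, scale, zero_point, -128, 127, torch.int8)

    # With this change: the output dtype can be requested explicitly
    bf16 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        q, scale, zero_point, -128, 127, torch.int8, out_dtype=torch.bfloat16)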

cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen leslie-fang-intel

Reviewed By: digantdesai

Differential Revision: D53590486

Pulled By: manuelcandales

Co-authored-by: kausik <kmaiti@habana.ai>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121450
Approved by: https://github.com/jerryzh168
kausik 2024-03-12 12:36:31 +00:00 committed by PyTorch MergeBot
parent 06d2392003
commit edf22f3a48


@@ -1,9 +1,9 @@
import torch
from torch.library import Library, impl
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from typing import Tuple
from torch._refs import _unsqueeze_multiple
from typing import Optional, Tuple
import torch
from torch._refs import _unsqueeze_multiple
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from torch.library import impl, Library
# Note: decomposed means decomposed quantized tensor, using decomposed so that the
# name is not too long
@@ -13,7 +13,7 @@ _DTYPE_TO_QVALUE_BOUNDS = {
torch.uint8: (0, 255),
torch.int8: (-128, 127),
torch.int16: (-(2**15), 2**15 - 1),
torch.int32: (-(2**31), 2**31 - 1)
torch.int32: (-(2**31), 2**31 - 1),
}
# Helper to check the passed in quant min and max are valid for the dtype
@@ -60,13 +60,26 @@ def quantize_per_tensor(
"""
if input.dtype == torch.bfloat16:
input = input.to(torch.float32)
assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
inv_scale = 1.0 / scale
return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
@impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta")
def quantize_per_tensor_meta(
input: torch.Tensor,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype
) -> torch.Tensor:
if input.dtype == torch.bfloat16:
input = input.to(torch.float32)
assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
return torch.empty_like(input, dtype=dtype)
quantized_decomposed_lib.define(
"quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
"int quant_min, int quant_max, ScalarType dtype) -> Tensor")
@@ -90,7 +103,14 @@ def quantize_per_tensor_tensor(
return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
def quantize_per_tensor_tensor_meta(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype
) -> torch.Tensor:
if input.dtype == torch.bfloat16:
input = input.to(torch.float32)
assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
@@ -122,7 +142,14 @@ def quantize_per_tensor_tensor2(
return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype)
@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta")
def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
def quantize_per_tensor_tensor2_meta(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: torch.Tensor,
quant_max: torch.Tensor,
dtype: torch.dtype
) -> torch.Tensor:
return quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype)
# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
@@ -131,7 +158,7 @@ def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
# We will revisit this later if we found there are no use cases for it
quantized_decomposed_lib.define(
"dequantize_per_tensor(Tensor input, float scale, int zero_point, "
"int quant_min, int quant_max, ScalarType dtype) -> Tensor")
"int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor(
@@ -140,7 +167,9 @@ def dequantize_per_tensor(
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
""" Affine dequantization for the Tensor using the same quantization parameters to map
from quantized values to floating point values
@@ -163,22 +192,40 @@ def dequantize_per_tensor(
dtype (torch.dtype): dtype for input Tensor (not used in computation,
reserved for pattern matching)
out_dtype (torch.dtype?): optional dtype for output Tensor
Returns:
dequantized float32 Tensor
"""
assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
if out_dtype is None:
out_dtype = torch.float32
if dtype in _DTYPE_TO_QVALUE_BOUNDS:
# TODO: investigate why
# (input - zero_point).to(torch.float32) * scale
# failed the test
return (input.to(torch.float32) - zero_point) * scale
return (input.to(out_dtype) - zero_point) * scale
else:
raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
@impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta")
def dequantize_per_tensor_meta(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
if out_dtype is None:
out_dtype = torch.float32
return torch.empty_like(input, dtype=out_dtype)
quantized_decomposed_lib.define(
"dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
"int quant_min, int quant_max, ScalarType dtype) -> Tensor")
"int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor_tensor(
@@ -187,7 +234,9 @@ def dequantize_per_tensor_tensor(
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
""" Affine dequantization for the Tensor using the same quantization parameters to map
from quantized values to floating point values
@@ -196,22 +245,33 @@ def dequantize_per_tensor_tensor(
"""
assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype, out_dtype=out_dtype)
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
def dequantize_per_tensor_tensor_meta(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
if out_dtype is None:
out_dtype = torch.float32
assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
if dtype in _DTYPE_TO_QVALUE_BOUNDS:
return torch.empty_like(input, dtype=torch.float32)
return torch.empty_like(input, dtype=out_dtype)
else:
raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
# TODO: remove other variants and keep this one
quantized_decomposed_lib.define(
"dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
"Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor")
"Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "CompositeExplicitAutograd")
def dequantize_per_tensor_tensor2(
@@ -220,7 +280,9 @@ def dequantize_per_tensor_tensor2(
zero_point: torch.Tensor,
quant_min: torch.Tensor,
quant_max: torch.Tensor,
dtype: torch.dtype
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
""" Affine dequantization for the Tensor using the same quantization parameters to map
from quantized values to floating point values
@@ -229,11 +291,21 @@ def dequantize_per_tensor_tensor2(
"""
assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype)
return dequantize_per_tensor(
input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype, out_dtype=out_dtype)
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta")
def dequantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype)
def dequantize_per_tensor_tensor2_meta(
input,
scale,
zero_point,
quant_min,
quant_max,
dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype)
quantized_decomposed_lib.define(
"choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
@@ -415,7 +487,7 @@ def quantize_per_channel_meta(
# We will revisit this later if we found there are no use cases for it
quantized_decomposed_lib.define(
"dequantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
"int quant_min, int quant_max, ScalarType dtype) -> Tensor")
"int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
def dequantize_per_channel(
@@ -425,7 +497,9 @@ def dequantize_per_channel(
axis: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
""" Affine per channel dequantization for the Tensor using the same quantization
parameters for each channel/axis to map from quantized values to floating point values
@@ -450,20 +524,24 @@ def dequantize_per_channel(
dtype (torch.dtype): requested dtype for output Tensor (not used in computation,
reserved for pattern matching)
out_dtype (torch.dtype?): optional dtype for output Tensor
Returns:
dequantized float32 Tensor
"""
assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
if out_dtype is None:
out_dtype = torch.float32
assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
input, permute_axis_list = _permute_to_axis_zero(input, axis)
res = torch.zeros_like(input, dtype=torch.float32)
res = torch.zeros_like(input, dtype=out_dtype)
for i in range(input.size(0)):
# TODO: investigate why
# (input[i] - zero_points[i]).to(torch.float32) * scales[i]
# (input[i] - zero_points[i]).to(out_dtype) * scales[i]
# failed the test
res[i] = (input[i].to(torch.float32) - zero_points[i]) * scales[i]
res[i] = (input[i].to(out_dtype) - zero_points[i]) * scales[i]
out = res.permute(tuple(permute_axis_list))
return out
@@ -476,12 +554,16 @@ def dequantize_per_channel_meta(
axis: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype
dtype: torch.dtype,
*,
out_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
if out_dtype is None:
out_dtype = torch.float32
assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
return torch.empty_like(input, dtype=torch.float32)
return torch.empty_like(input, dtype=out_dtype)
quantized_decomposed_lib.define(
"fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "