mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: X-link: https://github.com/pytorch/executorch/pull/2308 Note: The initial purpose of this PR is to draw suggestion and feedback regarding better alternative, if any. At present, dequantize op for decomposed quantized Tensor representation e.g. dequantize_per_tensor() assumes the output dtype as torch.float and hence, it does not have the output dtype in its operator argument list. However, this op signature becomes unusable when the assumption breaks. Because, in case the output dtype is different from torch.float, there is no way to specify the same during dequantization. This change is aimed at generalizing the signature of dequantize op like dequantize_per_tensor() for wider use-cases where the output dtype can be different from torch.float and needs to passed during dequantization. The proposal is to use an additional argument named 'output_dtype' to solve the problem. However, we would also like to have suggestion and feedback regarding any better alternative that can be used instead. cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen leslie-fang-intel Reviewed By: digantdesai Differential Revision: D53590486 Pulled By: manuelcandales Co-authored-by: kausik <kmaiti@habana.ai> Pull Request resolved: https://github.com/pytorch/pytorch/pull/121450 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
06d2392003
commit
edf22f3a48
|
|
@ -1,9 +1,9 @@
|
|||
import torch
|
||||
from torch.library import Library, impl
|
||||
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
|
||||
from typing import Tuple
|
||||
from torch._refs import _unsqueeze_multiple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch._refs import _unsqueeze_multiple
|
||||
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
|
||||
from torch.library import impl, Library
|
||||
|
||||
# Note: decomposed means decomposed quantized tensor, using decomposed so that the
|
||||
# name is not too long
|
||||
|
|
@ -13,7 +13,7 @@ _DTYPE_TO_QVALUE_BOUNDS = {
|
|||
torch.uint8: (0, 255),
|
||||
torch.int8: (-128, 127),
|
||||
torch.int16: (-(2**15), 2**15 - 1),
|
||||
torch.int32: (-(2**31), 2**31 - 1)
|
||||
torch.int32: (-(2**31), 2**31 - 1),
|
||||
}
|
||||
|
||||
# Helper to check the passed in quant min and max are valid for the dtype
|
||||
|
|
@ -60,13 +60,26 @@ def quantize_per_tensor(
|
|||
"""
|
||||
if input.dtype == torch.bfloat16:
|
||||
input = input.to(torch.float32)
|
||||
|
||||
assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
|
||||
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
|
||||
|
||||
inv_scale = 1.0 / scale
|
||||
return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
|
||||
|
||||
@impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta")
|
||||
def quantize_per_tensor_meta(
|
||||
input: torch.Tensor,
|
||||
scale: float,
|
||||
zero_point: int,
|
||||
quant_min: int,
|
||||
quant_max: int,
|
||||
dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
if input.dtype == torch.bfloat16:
|
||||
input = input.to(torch.float32)
|
||||
assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
|
||||
return torch.empty_like(input, dtype=dtype)
|
||||
|
||||
quantized_decomposed_lib.define(
|
||||
"quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
|
||||
"int quant_min, int quant_max, ScalarType dtype) -> Tensor")
|
||||
|
|
@ -90,7 +103,14 @@ def quantize_per_tensor_tensor(
|
|||
return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
|
||||
|
||||
@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
|
||||
def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
|
||||
def quantize_per_tensor_tensor_meta(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
zero_point: torch.Tensor,
|
||||
quant_min: int,
|
||||
quant_max: int,
|
||||
dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
if input.dtype == torch.bfloat16:
|
||||
input = input.to(torch.float32)
|
||||
assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
|
||||
|
|
@ -122,7 +142,14 @@ def quantize_per_tensor_tensor2(
|
|||
return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype)
|
||||
|
||||
@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta")
|
||||
def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
|
||||
def quantize_per_tensor_tensor2_meta(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
zero_point: torch.Tensor,
|
||||
quant_min: torch.Tensor,
|
||||
quant_max: torch.Tensor,
|
||||
dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
return quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype)
|
||||
|
||||
# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
|
||||
|
|
@ -131,7 +158,7 @@ def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_
|
|||
# We will revisit this later if we found there are no use cases for it
|
||||
quantized_decomposed_lib.define(
|
||||
"dequantize_per_tensor(Tensor input, float scale, int zero_point, "
|
||||
"int quant_min, int quant_max, ScalarType dtype) -> Tensor")
|
||||
"int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
|
||||
|
||||
@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
|
||||
def dequantize_per_tensor(
|
||||
|
|
@ -140,7 +167,9 @@ def dequantize_per_tensor(
|
|||
zero_point: int,
|
||||
quant_min: int,
|
||||
quant_max: int,
|
||||
dtype: torch.dtype
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
out_dtype: Optional[torch.dtype] = None
|
||||
) -> torch.Tensor:
|
||||
""" Affine dequantization for the Tensor using the same quantization parameters to map
|
||||
from quantized values to floating point values
|
||||
|
|
@ -163,22 +192,40 @@ def dequantize_per_tensor(
|
|||
dtype (torch.dtype): dtype for input Tensor (not used in computation,
|
||||
reserved for pattern matching)
|
||||
|
||||
out_dtype (torch.dtype?): optional dtype for output Tensor
|
||||
|
||||
Returns:
|
||||
dequantized float32 Tensor
|
||||
"""
|
||||
assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
|
||||
if out_dtype is None:
|
||||
out_dtype = torch.float32
|
||||
if dtype in _DTYPE_TO_QVALUE_BOUNDS:
|
||||
# TODO: investigate why
|
||||
# (input - zero_point).to(torch.float32) * scale
|
||||
# failed the test
|
||||
return (input.to(torch.float32) - zero_point) * scale
|
||||
return (input.to(out_dtype) - zero_point) * scale
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
|
||||
|
||||
@impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta")
|
||||
def dequantize_per_tensor_meta(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
zero_pointe: torch.Tensor,
|
||||
quant_min: int,
|
||||
quant_max: int,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
out_dtype: Optional[torch.dtype] = None
|
||||
) -> torch.Tensor:
|
||||
if out_dtype is None:
|
||||
out_dtype = torch.float32
|
||||
return torch.empty_like(input, dtype=out_dtype)
|
||||
|
||||
quantized_decomposed_lib.define(
|
||||
"dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
|
||||
"int quant_min, int quant_max, ScalarType dtype) -> Tensor")
|
||||
"int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
|
||||
|
||||
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "CompositeExplicitAutograd")
|
||||
def dequantize_per_tensor_tensor(
|
||||
|
|
@ -187,7 +234,9 @@ def dequantize_per_tensor_tensor(
|
|||
zero_point: torch.Tensor,
|
||||
quant_min: int,
|
||||
quant_max: int,
|
||||
dtype: torch.dtype
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
out_dtype: Optional[torch.dtype] = None
|
||||
) -> torch.Tensor:
|
||||
""" Affine dequantization for the Tensor using the same quantization parameters to map
|
||||
from quantized values to floating point values
|
||||
|
|
@ -196,22 +245,33 @@ def dequantize_per_tensor_tensor(
|
|||
"""
|
||||
assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
|
||||
assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
|
||||
return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
|
||||
return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype, out_dtype=out_dtype)
|
||||
|
||||
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
|
||||
def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
|
||||
def dequantize_per_tensor_tensor_meta(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
zero_point: torch.Tensor,
|
||||
quant_min: int,
|
||||
quant_max: int,
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
out_dtype: Optional[torch.dtype] = None
|
||||
) -> torch.Tensor:
|
||||
if out_dtype is None:
|
||||
out_dtype = torch.float32
|
||||
assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
|
||||
assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
|
||||
assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
|
||||
if dtype in _DTYPE_TO_QVALUE_BOUNDS:
|
||||
return torch.empty_like(input, dtype=torch.float32)
|
||||
return torch.empty_like(input, dtype=out_dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
|
||||
|
||||
# TODO: remove other variants and keep this one
|
||||
quantized_decomposed_lib.define(
|
||||
"dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
|
||||
"Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor")
|
||||
"Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
|
||||
|
||||
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "CompositeExplicitAutograd")
|
||||
def dequantize_per_tensor_tensor2(
|
||||
|
|
@ -220,7 +280,9 @@ def dequantize_per_tensor_tensor2(
|
|||
zero_point: torch.Tensor,
|
||||
quant_min: torch.Tensor,
|
||||
quant_max: torch.Tensor,
|
||||
dtype: torch.dtype
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
out_dtype: Optional[torch.dtype] = None
|
||||
) -> torch.Tensor:
|
||||
""" Affine dequantization for the Tensor using the same quantization parameters to map
|
||||
from quantized values to floating point values
|
||||
|
|
@ -229,11 +291,21 @@ def dequantize_per_tensor_tensor2(
|
|||
"""
|
||||
assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
|
||||
assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
|
||||
return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype)
|
||||
return dequantize_per_tensor(
|
||||
input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype, out_dtype=out_dtype)
|
||||
|
||||
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta")
|
||||
def dequantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype):
|
||||
return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype)
|
||||
def dequantize_per_tensor_tensor2_meta(
|
||||
input,
|
||||
scale,
|
||||
zero_point,
|
||||
quant_min,
|
||||
quant_max,
|
||||
dtype,
|
||||
*,
|
||||
out_dtype: Optional[torch.dtype] = None
|
||||
) -> torch.Tensor:
|
||||
return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype)
|
||||
|
||||
quantized_decomposed_lib.define(
|
||||
"choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
|
||||
|
|
@ -415,7 +487,7 @@ def quantize_per_channel_meta(
|
|||
# We will revisit this later if we found there are no use cases for it
|
||||
quantized_decomposed_lib.define(
|
||||
"dequantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
|
||||
"int quant_min, int quant_max, ScalarType dtype) -> Tensor")
|
||||
"int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor")
|
||||
|
||||
@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
|
||||
def dequantize_per_channel(
|
||||
|
|
@ -425,7 +497,9 @@ def dequantize_per_channel(
|
|||
axis: int,
|
||||
quant_min: int,
|
||||
quant_max: int,
|
||||
dtype: torch.dtype
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
out_dtype: Optional[torch.dtype] = None
|
||||
) -> torch.Tensor:
|
||||
""" Affine per channel dequantization for the Tensor using the same quantization
|
||||
parameters for each channel/axis to map from quantized values to floating point values
|
||||
|
|
@ -450,20 +524,24 @@ def dequantize_per_channel(
|
|||
dtype (torch.dtype): requested dtype for output Tensor (not used in computation,
|
||||
reserved for pattern matching)
|
||||
|
||||
out_dtype (torch.dtype?): optional dtype for output Tensor
|
||||
|
||||
Returns:
|
||||
dequantized float32 Tensor
|
||||
"""
|
||||
assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
|
||||
if out_dtype is None:
|
||||
out_dtype = torch.float32
|
||||
assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
|
||||
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
|
||||
input, permute_axis_list = _permute_to_axis_zero(input, axis)
|
||||
res = torch.zeros_like(input, dtype=torch.float32)
|
||||
res = torch.zeros_like(input, dtype=out_dtype)
|
||||
|
||||
for i in range(input.size(0)):
|
||||
# TODO: investigate why
|
||||
# (input[i] - zero_points[i]).to(torch.float32) * scales[i]
|
||||
# (input[i] - zero_points[i]).to(out_dtype) * scales[i]
|
||||
# failed the test
|
||||
res[i] = (input[i].to(torch.float32) - zero_points[i]) * scales[i]
|
||||
res[i] = (input[i].to(out_dtype) - zero_points[i]) * scales[i]
|
||||
|
||||
out = res.permute(tuple(permute_axis_list))
|
||||
return out
|
||||
|
|
@ -476,12 +554,16 @@ def dequantize_per_channel_meta(
|
|||
axis: int,
|
||||
quant_min: int,
|
||||
quant_max: int,
|
||||
dtype: torch.dtype
|
||||
dtype: torch.dtype,
|
||||
*,
|
||||
out_dtype: Optional[torch.dtype] = None
|
||||
) -> torch.Tensor:
|
||||
assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
|
||||
if out_dtype is None:
|
||||
out_dtype = torch.float32
|
||||
assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
|
||||
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
|
||||
return torch.empty_like(input, dtype=torch.float32)
|
||||
return torch.empty_like(input, dtype=out_dtype)
|
||||
|
||||
quantized_decomposed_lib.define(
|
||||
"fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user