diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst index a15e901bfd5..4e4ce90c605 100644 --- a/docs/source/quantization-support.rst +++ b/docs/source/quantization-support.rst @@ -250,6 +250,18 @@ the values observed during calibration (PTQ) or training (QAT). default_per_channel_weight_observer default_dynamic_quant_observer default_float_qparams_observer + AffineQuantizedObserverBase + Granularity + MappingType + PerAxis + PerBlock + PerGroup + PerRow + PerTensor + PerToken + TorchAODType + ZeroPointDomain + get_block_size torch.ao.quantization.fake_quantize ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/mypy.ini b/mypy.ini index 5ab02361d61..65f9ee43a6b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -79,6 +79,9 @@ ignore_missing_imports = True [mypy-torch.ao.quantization.experimental.fake_quantize] ignore_missing_imports = True +[mypy-torch.ao.quantization.pt2e._affine_quantization] +ignore_errors = True + # # Files with various errors. Mostly real errors, possibly some false # positives as well. diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index a51614a32e1..91be2f8e485 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -42,7 +42,6 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( ) from torch.export import export_for_training from torch.fx import Node -from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_quantization import ( NodeSpec as ns, PT2EQuantizationTestCase, @@ -1865,6 +1864,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): torch.ops.aten.batch_norm.default, ) + @parametrize( + "device", + ["cpu"] + (["cuda"] if TEST_CUDA else []) + (["hpu"] if TEST_HPU else []), + ) def test_move_exported_model_bn(self, device): """ Test switching batch_norm behavior between train and eval modes using @@ -2477,9 +2480,90 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): check_nn_module(node) -instantiate_parametrized_tests(TestQuantizePT2E) +@skipIfNoQNNPACK +class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase): + def test_channel_group_quantization(self): + from torch.ao.quantization.observer import MappingType, PerGroup, PerToken + from torch.ao.quantization.pt2e._affine_quantization import ( + AffineQuantizedMinMaxObserver, + ) -devices = ["cpu", "cuda"] -if TEST_HPU: - devices.append("hpu") -instantiate_device_type_tests(TestQuantizePT2E, globals(), only_for=devices) + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.linear.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + + act_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=None, + is_dynamic=False, + observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args( + # TODO: maybe align the arg name here + target_dtype=torch.uint8, + mapping_type=MappingType.SYMMETRIC, + granularity=PerToken(), + ), + ) + + weight_qspec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=None, + is_dynamic=False, + observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args( + target_dtype=torch.uint8, + mapping_type=MappingType.SYMMETRIC, + 
granularity=PerGroup(group_size=128), + ), + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + }, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 20) + + def forward(self, x): + return self.linear(x) + + node_occurrence = { + torch.ops.quant.quantize_affine: 2, + torch.ops.quant.dequantize_affine: 2, + } + node_list = [ + torch.ops.quant.quantize_affine, + torch.ops.quant.dequantize_affine, + torch.ops.quant.quantize_affine, + torch.ops.quant.dequantize_affine, + ] + example_inputs = (torch.randn(5, 128),) + self._test_quantizer( + M().eval(), + example_inputs, + BackendAQuantizer(), + node_occurrence, + node_list, + is_debug_mode=True, + ) + + +instantiate_parametrized_tests(TestQuantizePT2E) diff --git a/test/test_quantization.py b/test/test_quantization.py index b7a876bcdc3..61a8e310c7a 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -87,6 +87,7 @@ try: from quantization.pt2e.test_metadata_porting import TestMetaDataPorting # noqa: F401 from quantization.pt2e.test_numeric_debugger import TestNumericDebugger # noqa: F401 from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2E # noqa: F401 + from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EAffineQuantization # noqa: F401 from quantization.pt2e.test_representation import TestPT2ERepresentation # noqa: F401 from quantization.pt2e.test_xnnpack_quantizer import TestXNNPACKQuantizer # noqa: F401 from quantization.pt2e.test_xnnpack_quantizer import TestXNNPACKQuantizerModels # noqa: F401 diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index 7675741ea42..99dcac1e23f 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -168,6 +168,20 @@ __all__ = [ "prepare_for_propagation_comparison", "extract_results_from_loggers", "compare_results", + # from torchao, should be merged with torchao + # in the future + "AffineQuantizedObserverBase", + "Granularity", + "MappingType", + "PerAxis", + "PerBlock", + "PerGroup", + "PerRow", + "PerTensor", + "PerToken", + "TorchAODType", + "ZeroPointDomain", + "get_block_size", ] diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 27317b645f6..ad43020d713 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -1,5 +1,8 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs +# temporarily skip RUF for this file for now, we can re-enable +# after move the affine quantization related things to torchao +# noqa: RUF """ This module implements observers which are used to collect statistics about the values observed during calibration (PTQ) or training (QAT). 
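(Illustrative note, not part of the patch.) The granularity classes exported above map onto the `block_size` argument of the affine quantization primitives. Assuming this patch is applied, the following minimal sketch shows how `get_block_size` (added in the observer.py hunks below) resolves a granularity into a block size, using the shapes from the new test above (`example_inputs` of shape (5, 128) and the weight of `nn.Linear(128, 20)`):

```python
from torch.ao.quantization.observer import PerAxis, PerGroup, PerTensor, get_block_size

act_shape = (5, 128)      # activation shape from example_inputs in the new test
weight_shape = (20, 128)  # weight shape of nn.Linear(128, 20)

# Per-tensor: a single set of qparams, so the block covers the whole tensor.
assert get_block_size(act_shape, PerTensor()) == (5, 128)

# Per-axis along dim 0: one set of qparams per row of the weight.
assert get_block_size(weight_shape, PerAxis(axis=0)) == (1, 128)

# Per-group with group_size=128 (the weight_qspec above): each group of 128
# consecutive elements along the last dim shares one scale/zero_point.
assert get_block_size(weight_shape, PerGroup(group_size=128)) == (1, 128)
```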
@@ -54,6 +57,18 @@ __all__ = [ "RecordingObserver", "ReuseInputObserver", "UniformQuantizationObserverBase", + "AffineQuantizedObserverBase", + "Granularity", + "MappingType", + "PerAxis", + "PerBlock", + "PerGroup", + "PerRow", + "PerTensor", + "PerToken", + "TorchAODType", + "ZeroPointDomain", + "get_block_size", ] @@ -1584,6 +1599,258 @@ class ReuseInputObserver(ObserverBase): ) +""" +# Experimental Affine Quantization Feature START +We plan to merge the following with torchao repo after we move pt2e flow to torchao +copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py +""" +from dataclasses import dataclass +from enum import auto, Enum + + +class MappingType(Enum): + """How floating point number is mapped to integer number + + symmetric mapping means floating point range is symmetrically mapped to integer range + let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4) + we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7) + e.g. scale = (10.2 - (-10.2)) / (7 - (-8)) + + SYMMETRIC_NO_CLIPPING_ERR is a variant of symmetric mapping, where the scale is the max of smin + and smax, where smin = min_val_neg / quant_min, and smax = max_val_pos / quant_max. By calculating + smin and smax individually, there can be less round error on negative values, and no out-of-range + of all floating point values. + + asymmetric mapping means we just directly map the floating point range to integer range, + for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter + based on this mapping + e.g. scale = (10.2 - (-3.5)) / (7 - (-8)) + """ + + SYMMETRIC = auto() + SYMMETRIC_NO_CLIPPING_ERR = auto() + ASYMMETRIC = auto() + + +class ZeroPointDomain(Enum): + """Enum that indicate whether zero_point is in integer domain or floating point domain + + integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer) + float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + none domain: quantized_val = (float_val / scale) + """ + + INT = auto() + FLOAT = auto() + NONE = auto() + + +class TorchAODType(Enum): + """ + Placeholder for dtypes that do not exist in PyTorch core yet. + """ + + # torch.int1 to torch.int7 will be added to PyTorch 2.6 + # These will remain here for BC with older PyTorch versions + INT1 = auto() + INT2 = auto() + INT3 = auto() + INT4 = auto() + INT5 = auto() + INT6 = auto() + INT7 = auto() + + +@dataclass(frozen=True) +class Granularity: + """ + Base class for representing the granularity of quantization. + + This class serves as a parent for specific granularity types used in + quantization operations, such as per-tensor or per-axis quantization. + """ + + +@dataclass(frozen=True) +class PerBlock(Granularity): + """ + Represents per-block granularity in quantization. See + :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for + `block_size` + + Attributes: + block_size (Tuple[int, ...]): The size of each quantization group + """ + + block_size: Tuple[int, ...] + + +@dataclass(frozen=True) +class PerTensor(Granularity): + """ + Represents per-tensor granularity in quantization. + + This granularity type calculates the quantization parameters + based off the entire tensor. + + """ + + +@dataclass(frozen=True) +class PerAxis(Granularity): + """ + Represents per-axis granularity in quantization. 
+
+    This granularity type calculates different quantization parameters
+    along a specified axis of the tensor.
+
+    For example, if the input tensor has shape [8, 16] and axis=0, then
+    the quantization parameters are calculated for each row of the tensor,
+    giving a total of 8 quantization parameters.
+
+    Attributes:
+        axis (int): The axis along which reduction is performed.
+    """
+
+    axis: int
+
+
+@dataclass(frozen=True)
+class PerGroup(Granularity):
+    """
+    Represents per-channel group granularity in quantization.
+
+    This granularity type calculates different quantization parameters
+    for each group of `group_size` elements.
+
+    For example, if the input tensor has shape [8, 16] and the group size is 4,
+    then the input tensor is reshaped to [32, 4] and quantization parameters
+    are calculated for each group of 4 elements, giving a total of
+    32 quantization parameters.
+
+    Attributes:
+        group_size (int): The size of each quantization group
+    """
+
+    group_size: int
+
+
+class PerRow(Granularity):
+    """
+    Represents row-wise granularity in quantization.
+
+    This is a special case of per-axis quantization and is unique to Float8 matmuls,
+    where the input is quantized with a block_size of (1, ..., input.shape[-1]) and
+    the weight is quantized with a block_size of (1, weight.shape[1]).
+    """
+
+
+class PerToken(Granularity):
+    """
+    Represents per-token granularity in quantization.
+
+    This granularity type calculates a different set of quantization parameters
+    for each token, which is represented as the last dimension of the tensor.
+
+    For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
+    with 4 elements each, and we will calculate 6 sets of quantization parameters,
+    one for each token.
+
+    If the input tensor has only two dimensions, e.g. [8, 16], then this is
+    equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
+    """
+
+
+def get_block_size(
+    input_shape: Tuple[int, ...], granularity: Granularity
+) -> Tuple[int, ...]:
+    """Get the block size based on the input shape and granularity type.
+ + Args: + input_shape: The input tensor shape possibly more than 2 dimensions + granularity: The granularity type of the quantization + """ + assert isinstance( + granularity, Granularity + ), "Please provide an instance of Granularity, not subclass of it" + if isinstance(granularity, PerTensor): + return input_shape + elif isinstance(granularity, PerAxis): + block_size = list(input_shape) + block_size[granularity.axis] = 1 + return tuple(block_size) + elif isinstance(granularity, PerRow): + return (1,) * (len(input_shape) - 1) + (input_shape[-1],) + elif isinstance(granularity, PerGroup): + assert ( + len(input_shape) == 2 + ), f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" + return (1, granularity.group_size) + elif isinstance(granularity, PerToken): + block_size = list(input_shape) + block_size[-1] = input_shape[-1] + return tuple(block_size) + raise ValueError(f"Unsupported Granularity: {granularity}") + + +class AffineQuantizedObserverBase(ABC, torch.nn.Module): + """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) + + Args: + `granularity` and `block_size`: The granularity of the quantization, + must specify at least one, if both are specified `block_size` takes precedence + Current supported granularity type are `PerTensor` and `PerAxis` + other args: please see `:class:torchao.dtypes.AffineQuantizedTensor` + """ + + with_args = classmethod(_with_args) + + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + super().__init__() + assert granularity is not None, "granularity is None" + + self.mapping_type = mapping_type + self.target_dtype = target_dtype + self.granularity = granularity + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.scale_dtype = scale_dtype + self.zero_point_dtype = zero_point_dtype + self.preserve_zero = preserve_zero + self.zero_point_domain = zero_point_domain + # populatd during forward + self.block_size = None + self.original_dtype = None + + @abstractmethod + def forward(self, input: torch.Tensor) -> torch.Tensor: + """forward function should take the input tensor + and updates internal stats and return the original input Tensor + """ + + @abstractmethod + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate quantization parameter based on the stats attached to the observer module + and returns a tuple of scale and zero_point Tensor + """ + + def _is_observer_script_module(mod, obs_type_name): """Returns true if given mod is an instance of Observer script module.""" if isinstance(mod, torch.jit.RecursiveScriptModule): @@ -1594,10 +1861,17 @@ def _is_observer_script_module(mod, obs_type_name): return False +# Experimental Affine Quantization Feature END + + def _is_activation_post_process(module): return isinstance( module, - (torch.ao.quantization.ObserverBase, torch.ao.quantization.FakeQuantizeBase), + ( + torch.ao.quantization.ObserverBase, + torch.ao.quantization.FakeQuantizeBase, + AffineQuantizedObserverBase, + ), ) or _is_observer_script_module(module, 
"quantization.observer") diff --git a/torch/ao/quantization/pt2e/_affine_quantization.py b/torch/ao/quantization/pt2e/_affine_quantization.py new file mode 100644 index 00000000000..4d2e2b0ba41 --- /dev/null +++ b/torch/ao/quantization/pt2e/_affine_quantization.py @@ -0,0 +1,775 @@ +# copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py +# and https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py +# PLESE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC +import logging +from abc import ABCMeta +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch.ao.quantization.observer import ( + AffineQuantizedObserverBase, + get_block_size, + MappingType, + TorchAODType, + ZeroPointDomain, +) +from torch.fx import Node + + +ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: + +logger = logging.getLogger(__name__) + +FP8_TYPES = { + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, +} +_SUB_BYTE_UINT_BOUNDS = { + torch.uint1: (0, 2**1 - 1), + torch.uint2: (0, 2**2 - 1), + torch.uint3: (0, 2**3 - 1), + torch.uint4: (0, 2**4 - 1), + torch.uint5: (0, 2**5 - 1), + torch.uint6: (0, 2**6 - 1), + torch.uint7: (0, 2**7 - 1), +} + +""" +Map from dtype to the bound value of integers +TODO: maybe can replace this with call to torch.iinfo +""" +_DTYPE_TO_QVALUE_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int16: (-(2**15), 2**15 - 1), + torch.int32: (-(2**31), 2**31 - 1), +} +_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS) + + +def _is_float8_type(dtype: torch.dtype) -> bool: + fp8_types = { + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + } + return dtype in fp8_types + + +# TODO: decide on if we want to allow custom quant_min/quant_max here +def _get_and_check_qmin_qmax(dtype, quant_min, quant_max): + """Get quant_min and quant_max args based on dtype and also + verify that they are within the range of possible quant_min/quant_max + for dtype + """ + if dtype in FP8_TYPES: + quant_min_lower_bound, quant_max_upper_bound = ( + torch.finfo(dtype).min, + torch.finfo(dtype).max, + ) + elif dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise ValueError(f"Unsupported dtype: {dtype}") + else: + quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] + if quant_min is None: + quant_min = quant_min_lower_bound + if quant_max is None: + quant_max = quant_max_upper_bound + + assert quant_min >= quant_min_lower_bound, ( + "quant_min out of bound for dtype, " + f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" + ) + + assert quant_max <= quant_max_upper_bound, ( + "quant_max out of bound for dtype, " + f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" + ) + return quant_min, quant_max + + +def _get_reduction_params(block_size, input_size): + """Given block_size and input size find the parameters for reduction: + + Output: + shape_for_reduction: the shape we use to `view` input to prepare it for reduction + reduction_dims: the dims we'll do reduction over + + Example:: + Input: + block_size: (3, 3, 2, 10) + input_size: (3, 3, 10, 10) + + Output: + shape_for_reduction: (3, 3, 5, 2, 10) + reduction_dim: [0, 1, 3, 4] + """ + assert len(block_size) == len(input_size) + shape_for_reduction = [] + reduction_dims = [] + cur_dim = 0 + for i in range(len(block_size)): + if 
block_size[i] != input_size[i] and block_size[i] > 1: + assert input_size[i] % block_size[i] == 0, ( + f"Expecting input size at {i} dimension: " + f"{input_size[i]} to be divisible by block_size at {i} dimension: {block_size[i]}" + ) + shape_for_reduction.append(input_size[i] // block_size[i]) + shape_for_reduction.append(block_size[i]) + # reduce over the block_size[i] dim + reduction_dims.append(cur_dim + 1) + cur_dim += 2 + else: + # block_size[i] == input_size[i] or block_size[i] == 1 + shape_for_reduction.append(input_size[i]) + # we only need to reduce over the dimension if block_size is greater than 1 + # otherwise it's already the same as reduced dimension + if block_size[i] != 1: + reduction_dims.append(cur_dim) + cur_dim += 1 + return shape_for_reduction, reduction_dims + + +def _register_custom_op(lib): + """This decorator is used to preserve some high level operators for torch.export.export + while still allow them to be decomposed for inductor path + + requirement: make sure `fn.__name__[1:]` is the operator name you want to register + + NOTE: This should be applied at the top, after all other decorators have been applied + NOTE: We haven't tested the case when `fn` accepts tensor subclass instance as input, + e.g. uint4 tensor subclass instance, and we'll probably need to figure out what would make + sense for downstream system (like executorch) to accept as well + + Example: + lib = torch.library.Library("my_namespace', "FRAGMENT") + + register_custom_op = _register_custom_op(lib) + + @register_custom_op + def _the_op_that_needs_to_be_preserved(...) + ... + + # after this, `_the_op_that_needs_to_be_preserved` will be preserved as + # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after + # torch.export.export / torch._export.export_for_training + + """ + from torch._inductor.decomposition import register_decomposition + + def decorator(fn): + from torch._library.infer_schema import infer_schema + + # expecting fn.__name__ starts with `_` and we want to take the rest + # to be the name of the custom op + assert ( + fn.__name__[0] == "_" + ), f"Expecting function name starts with `_`, got {fn.__name__}" + assert not any( + c in fn.__name__ for c in ".<>" + ), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" + op_name = fn.__name__[1:] + schema = op_name + infer_schema(fn, mutates_args={}) + lib.define(schema) + lib.impl(op_name, fn, "CompositeImplicitAutograd") + + lib_namespace = lib.ns + op = getattr(getattr(torch.ops, lib_namespace), op_name) + register_decomposition([op])(fn) + return op + + return decorator + + +quant_lib = torch.library.Library("quant", "FRAGMENT") # noqa: TOR901 + +register_custom_op = _register_custom_op(quant_lib) + + +def choose_qparams_affine_with_min_max( + min_val: torch.Tensor, + max_val: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, +) -> Tuple[torch.Tensor, torch.Tensor]: + """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine` + operator that pass in min_val and max_val directly instead of deriving these from a single input. 
+ This is used for observers in static quantization where min_val and max_val may be obtained through + tracking all the data in calibration data set. + + Args: + Mostly same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`. with one + difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val + and then scale/zero_point, we pass in min_val/max_val directly + """ + return _choose_qparams_affine( + None, + mapping_type.name, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain.name if zero_point_domain is not None else None, + min_val, + max_val, + ) + + +@register_custom_op +def _choose_qparams_affine( + input: Optional[torch.Tensor], + mapping_type: str, + block_size: List[int], + target_dtype: torch.dtype, + quant_min: Optional[Union[int, float, bool]] = None, + quant_max: Optional[Union[int, float, bool]] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[str] = "INT", + min_val: Optional[torch.Tensor] = None, + max_val: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """op definition that has compatible signatures with custom op library + + The op does the following: + 1. figure out the dimension for reduction based on block_size + 2. find min_val/max_val based on the dimension for reduction + 3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero` + and `zero_point_domain` + """ + quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) + assert mapping_type in [ + MappingType.SYMMETRIC.name, + MappingType.SYMMETRIC_NO_CLIPPING_ERR.name, + MappingType.ASYMMETRIC.name, + ], f"Unsupported mapping type: {mapping_type}" + if target_dtype in FP8_TYPES: + assert ( + mapping_type == MappingType.SYMMETRIC.name + ), f"Only symmetric quantization is supported for FP8 types, got {mapping_type}" + + if input is not None: + if scale_dtype is None: + scale_dtype = input.dtype + if zero_point_dtype is None: + zero_point_dtype = input.dtype + if eps is None: + eps = torch.finfo(input.dtype).eps + + assert ( + len(block_size) == input.dim() + ), f"Got input dim:{input.dim()}, block_size: {block_size}" + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + input = input.view(shape_for_reduction) + + min_val = torch.amin(input, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input, dim=reduction_dims, keepdim=False) + else: + assert ( + min_val is not None and max_val is not None + ), "Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}" + assert ( + min_val.dtype == max_val.dtype + ), "Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}" + + if scale_dtype is None: + scale_dtype = min_val.dtype + if zero_point_dtype is None: + zero_point_dtype = min_val.dtype + if eps is None: + eps = torch.finfo(min_val.dtype).eps + + if preserve_zero: + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + else: + min_val_neg = min_val + max_val_pos = max_val + + if ( + mapping_type == MappingType.SYMMETRIC.name + or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name + ): + # scales + if mapping_type == MappingType.SYMMETRIC.name: + max_val_pos = 
torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + else: + assert mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name + # calculate smin and smax individually and choose the larger one. For example, if quant_min = -8 and + # quant_max = 7. + # - If smin is bigger: There would be coverage on negative values down to -8, and less rounding + # error than the existing SYMMETRIC case. + # - If smax is bigger: it covers the positive values up to 7. The round + # error may be bigger than the existing SYMMETRIC case. Either way, there's no out-of-range fp values after + # quantization. + smin = min_val_neg / float(quant_min) + smax = max_val_pos / float(quant_max) + mask = smin > smax + scale = torch.where(mask, smin, smax) + # zeros + if not preserve_zero: + raise ValueError( + "preserve_zero == False is not supported for symmetric quantization" + ) + if ( + zero_point_domain is not None + and zero_point_domain != ZeroPointDomain.INT.name + ): + raise ValueError( + "zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization" + ) + scale = torch.clamp(scale, min=eps) + zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) + else: + assert mapping_type == MappingType.ASYMMETRIC.name + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.clamp(scale, min=eps) + if zero_point_domain == ZeroPointDomain.NONE.name: + zero_point = None + else: + if preserve_zero: + zero_point = quant_min - torch.round(min_val_neg / scale) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + else: + assert ( + zero_point_domain == ZeroPointDomain.FLOAT.name + ), "if not preserve_zero, zero_point must be in FLOAT domain" + mid_point = (quant_max + quant_min + 1) / 2 + zero_point = min_val_neg + scale * mid_point + + if zero_point is not None: + zero_point = zero_point.to(dtype=zero_point_dtype) + return scale.to(dtype=scale_dtype), zero_point + + +@torch.no_grad() +def quantize_affine( + input: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + output_dtype: torch.dtype, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, +) -> torch.Tensor: + """ + Args: + input (torch.Tensor): original float32, float16 or bfloat16 Tensor + block_size: (Tuple[int, ...]): granularity of quantization, + this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + scale (float): quantization parameter for affine quantization + zero_point (int): quantization parameter for affine quantization + output_dtype (torch.dtype): requested dtype (e.g. 
torch.uint8) for output Tensor + quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype + quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during + quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) + value during quantization + default is ZeroPointDomain.INT + + Note: + How can block_size represent different granularities? + let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different + granularities: + + granularity type | block_size + per_tensor | (3, 3, 10, 10) + per_axis (axis=0) | (1, 3, 10, 10) + per_axis (axis=1) | (3, 1, 10, 10) + per_group (groupsize=2) | (3, 3, 10, 2) + per_group (groupsize=2) for axis = 3 | (3, 3, 2, 10) + + + Output: + quantized tensor with requested dtype + """ + return _quantize_affine( + input, + block_size, + scale, + zero_point, + output_dtype, + quant_min, + quant_max, + zero_point_domain.name if zero_point_domain is not None else None, + ) + + +@register_custom_op +def _quantize_affine( + input: torch.Tensor, + block_size: List[int], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + output_dtype: torch.dtype, + quant_min: Optional[Union[int, float, bool]] = None, + quant_max: Optional[Union[int, float, bool]] = None, + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, +) -> torch.Tensor: + """op definition that has compatible signatures with custom op library + + Note: + zero_point_domain is optional specifies how we quantize the floating point to quantized data: + INT: quantized_val = (float_val / scale) (integer) + zero_point (integer) + FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale + None: quantized_val = (float_val / scale) | this is primarily used for floatx quantization + Where we do not want to round values to nearest integer and instead scale and cast. + """ + quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max) + # workaround for uintx dtypes, since we don't have native Uintx dtype connected with + # torch.uintx dtypes yet + if output_dtype in _SUB_BYTE_UINT_BOUNDS: + output_dtype = torch.uint8 + return _quantize_affine_no_dtype_cast( + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + zero_point_domain, + ).to(output_dtype) + + +def _quantize_affine_no_dtype_cast( + input: torch.Tensor, + block_size: List[int], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + quant_min: Union[int, float], + quant_max: Union[int, float], + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, +) -> torch.Tensor: + """ + The op does the following: + 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + the shape after reduction + 2. quantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain + 3. 
reshape the quantized result to origianl shape + """ + # TODO: validations + # TODO: validate scale/zero_point dimensions are compatible with block_size + assert input.dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + ], f"Unsupported input dtype: {input.dtype}" + assert ( + len(block_size) == input.dim() + ), f"Got input dim:{input.dim()}, block_size: {block_size}" + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + original_shape = input.shape + input = input.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + scale = scale.view(shape_after_reduction) + if zero_point is not None: + zero_point = zero_point.view(shape_after_reduction) + + if zero_point_domain == ZeroPointDomain.INT.name: + quant = torch.clamp( + torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max + ) + elif zero_point_domain == ZeroPointDomain.NONE.name: + assert ( + zero_point is None + ), "zero_point should be None when zero_point_domain is NONE" + quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max) + elif zero_point_domain is None: + # This case handles quantization for float8 we expect no zero point and no zero point domain + assert ( + zero_point is None + ), "zero_point should be None when zero_point_domain is None" + quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + else: + assert zero_point_domain == ZeroPointDomain.FLOAT.name + mid_point = (quant_max + quant_min + 1) / 2 + min_val = zero_point - scale * mid_point + quant = torch.clamp( + torch.round((input - min_val) / scale), quant_min, quant_max + ) + quant = quant.view(original_shape) + + return quant + + +def dequantize_affine( + input: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + input_dtype: torch.dtype, + quant_min: Optional[Union[int, float]] = None, + quant_max: Optional[Union[int, float]] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + *, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Args: + input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument + block_size: (List[int]): granularity of quantization, + this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + scale (Tensor): quantization parameter for affine quantization + zero_point (Tensor): quantization parameter for affine quantization + input_dtype (torch.dtype): requested dtype (e.g. 
torch.uint8) for output Tensor + quant_min (Optional[int]): minimum quantized value for input Tensor + quant_max (Optional[int]): maximum quantized value for input Tensor + output_dtype (torch.dtype): dtype for output Tensor, default is fp32 + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during + quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) + value during quantization + default is ZeroPointDomain.INT + + Output: + dequantized Tensor, with requested dtype or fp32 + """ + return _dequantize_affine( + input, + block_size, + scale, + zero_point, + input_dtype, + quant_min, + quant_max, + zero_point_domain.name if zero_point_domain is not None else None, + output_dtype=output_dtype, + ) + + +@register_custom_op +def _dequantize_affine( + input: torch.Tensor, + block_size: List[int], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + input_dtype: torch.dtype, + quant_min: Optional[Union[int, float, bool]] = None, + quant_max: Optional[Union[int, float, bool]] = None, + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """op definition that has compatible signatures with custom op library""" + # TODO: validate scale/zero_point dimensions are compatible with block_size + if input_dtype not in _SUB_BYTE_UINT_BOUNDS: + assert ( + input.dtype == input_dtype + ), f"Expected: {input_dtype}, got: {input.dtype}" + assert output_dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + ], f"Unsupported output dtype: {output_dtype}" + quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) + return _dequantize_affine_no_dtype_check( + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + zero_point_domain, + output_dtype, + ) + + +def _dequantize_affine_no_dtype_check( + input: torch.Tensor, + block_size: List[int], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + quant_min: Union[int, float], + quant_max: Union[int, float], + zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """This function converts AQT tensors to their high precision floating point representation + + The op does the following: + 1. figure out the dimension for reduction based on block_size, also reshape the input to align with + the shape after reduction + 2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain + 3. reshape the quantized result to origianl shape and change dtype to the output_dtype + """ + assert ( + len(block_size) == input.dim() + ), f"Got input dim:{input.dim()}, block_size: {block_size}" + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, input.size() + ) + original_shape = input.shape + input = input.view(shape_for_reduction) + shape_after_reduction = shape_for_reduction + for i in reduction_dims: + shape_after_reduction[i] = 1 + scale = scale.view(shape_after_reduction) + + if zero_point is not None: + zero_point = zero_point.view(shape_after_reduction) + + if zero_point_domain == ZeroPointDomain.INT.name: + # Force a copy to avoid input modification due + # to upcoming in-place operations. 
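+        # i.e. dequant = (q - zero_point) * scale, with the subtraction done
+        # in int32 before casting to output_dtype.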
+ dequant = input.to(torch.int32, copy=True) + if zero_point is not None: + dequant = dequant - zero_point.to(torch.int32) + dequant = dequant.to(output_dtype) + dequant = dequant * scale + elif zero_point_domain == ZeroPointDomain.NONE.name: + assert ( + zero_point is None + ), "zero_point should be None when zero_point_domain is NONE" + dequant = input.to(output_dtype) + dequant = dequant * scale + elif zero_point_domain is None: + # This case handles dequantization for float8 we expect no zero point and no zero point domain + assert ( + zero_point is None + ), "zero_point should be None when zero_point_domain is None" + assert _is_float8_type( + input.dtype + ), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" + dequant = input.to(output_dtype) + dequant = dequant * scale + else: + assert ( + zero_point_domain == ZeroPointDomain.FLOAT.name + ), f"Unexpected zero point domain: {zero_point_domain}" + # TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this) + mid_point = (quant_max + quant_min + 1) / 2 + # This should allocate new memory and avoid input modification + dequant = input - mid_point + dequant = dequant.to(output_dtype) + dequant *= scale + if zero_point is not None: + dequant += zero_point + + return dequant.view(original_shape).to(output_dtype) + + +class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase): + def forward(self, input: torch.Tensor): + if input.numel() == 0: + return input + + input_detached = input.detach() + self.original_dtype = input_detached.dtype + assert self.granularity is not None, "granularity is None" + self.block_size = get_block_size(input_detached.shape, self.granularity) + + shape_for_reduction, reduction_dims = _get_reduction_params( + self.block_size, input_detached.size() + ) + input_detached = input_detached.view(shape_for_reduction) + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): + self.min_val = min_val + self.max_val = max_val + else: + assert ( + self.min_val.shape == min_val.shape + ), f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" + assert ( + self.max_val.shape == max_val.shape + ), f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + min_val = torch.min(self.min_val, min_val) + max_val = torch.max(self.max_val, max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + # returning original input + return input + + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + assert hasattr(self, "min_val") and hasattr( + self, "max_val" + ), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + [], # BlockSize is not needed because the min/max are already reduced + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + + def convert(self, model: torch.fx.GraphModule, observer_node: Node): + print("calling convert") + from torch.ao.quantization.fx.utils import create_getattr_from_value + + scale, zero_point = self.calculate_qparams() + with 
model.graph.inserting_before(observer_node): + assert self.block_size is not None, "Expecting block_size to be populated" + assert ( + self.original_dtype is not None + ), "Expecting original_dtype to be populated" + scale_node = create_getattr_from_value(model, model.graph, "_scale", scale) + zero_point_node = create_getattr_from_value( + model, model.graph, "_zero_point", zero_point + ) + q_node = model.graph.call_function( + torch.ops.quant.quantize_affine, + ( + observer_node.args[0], + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {}, + ) + dq_node = model.graph.call_function( + torch.ops.quant.dequantize_affine, + ( + q_node, + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {"output_dtype": self.original_dtype}, + ) + observer_node.replace_all_uses_with(dq_node) + model.graph.erase_node(observer_node) diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 7403224b15c..084bdef2d70 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -1305,6 +1305,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase): m = prepare_qat_pt2e(m, quantizer) else: m = prepare_pt2e(m, quantizer) + if is_debug_mode: + print("prepared model:", m) # Calibrate m(*example_inputs) m = convert_pt2e(m)
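For reference (illustrative only, not part of the diff): the primitives in `_affine_quantization.py` reduce to simple per-block arithmetic. The sketch below shows what `choose_qparams_affine_with_min_max` followed by `quantize_affine`/`dequantize_affine` computes for the configuration used in the new test (`MappingType.SYMMETRIC`, `ZeroPointDomain.INT`, uint8 bounds), written with plain tensor ops rather than the new `torch.ops.quant.*` operators:

```python
import torch

# Toy activation; AffineQuantizedMinMaxObserver with PerToken granularity
# tracks one (min, max) pair per token, i.e. per row here.
x = torch.randn(5, 128)
min_val = torch.amin(x, dim=-1)
max_val = torch.amax(x, dim=-1)

quant_min, quant_max = 0, 255          # uint8 bounds
eps = torch.finfo(torch.float32).eps

# MappingType.SYMMETRIC: mirror the larger magnitude of (min, max) around zero.
min_neg = torch.min(min_val, torch.zeros_like(min_val))
max_pos = torch.max(max_val, torch.zeros_like(max_val))
max_abs = torch.max(-min_neg, max_pos)
scale = torch.clamp(max_abs / (float(quant_max - quant_min) / 2), min=eps)

# ZeroPointDomain.INT: zero_point is the integer mid-point of the quantized range.
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))  # 128

# Round trip corresponding to quantize_affine / dequantize_affine for this
# configuration (block reshaping omitted: qparams are already one per row).
q = torch.clamp(torch.round(x / scale[:, None]) + zero_point[:, None], quant_min, quant_max)
x_hat = (q - zero_point[:, None]) * scale[:, None]
```

This also shows why `AffineQuantizedMinMaxObserver.calculate_qparams` can pass an empty `block_size` to `choose_qparams_affine_with_min_max`: the per-block reduction already happened in `forward`, so only the tracked min/max are needed.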