diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst
index ba3eac56a6e..f54888ec38c 100644
--- a/docs/source/quantization-support.rst
+++ b/docs/source/quantization-support.rst
@@ -120,11 +120,15 @@ This module contains a few CustomConfig classes that's used in both eager mode a
     ConvertCustomConfig
     StandaloneModuleConfigEntry

-torch.ao.quantization.pt2e (quantization in pytorch 2.0 export)
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+torch.ao.quantization.quantizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: torch.ao.quantization.quantizer
+
+torch.ao.quantization.pt2e (quantization in pytorch 2.0 export implementation)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 .. automodule:: torch.ao.quantization.pt2e
-.. automodule:: torch.ao.quantization.pt2e.quantizer
 .. automodule:: torch.ao.quantization.pt2e.representation

 torch (quantization related functions)
diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py
index ca88d1a0009..79b74be2fa6 100644
--- a/test/inductor/test_inductor_freezing.py
+++ b/test/inductor/test_inductor_freezing.py
@@ -12,13 +12,13 @@ import weakref

 import torch
 import torch._dynamo as torchdynamo
-import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
 from torch import nn
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx
 from torch._inductor.utils import override_lowering, run_and_get_code
-from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer import X86InductorQuantizer
 from torch.testing import FileCheck
 from torch.testing._internal.common_quantization import (
     skipIfNoDynamoSupport,
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 2fc1c28a9f8..537d3bf1511 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -15,7 +15,7 @@ from torch.ao.quantization import (
     ObserverOrFakeQuantize,
     QConfigMapping,
 )
-from torch.ao.quantization.pt2e.quantizer import (
+from torch.ao.quantization.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
     EmbeddingQuantizer,
@@ -27,10 +27,10 @@ from torch.ao.quantization.pt2e.quantizer import (
     Quantizer,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization.pt2e.quantizer.composable_quantizer import (  # noqa: F811
+from torch.ao.quantization.quantizer.composable_quantizer import (  # noqa: F811
     ComposableQuantizer,
 )
-from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
 )
 from torch.ao.quantization.quantize_pt2e import (
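The test updates above show the only change downstream code needs: the quantizer package moves up one level, out of pt2e. A minimal before/after import sketch; the module paths are taken directly from this diff, the usage around them is illustrative only:

```python
# Before this change (old path, removed by this PR):
#   from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer
#   import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
# After this change (new path):
from torch.ao.quantization.quantizer import X86InductorQuantizer
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

quantizer = X86InductorQuantizer()  # the class itself is unchanged, only its location moved
```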
diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py
index 6f02d6a94e4..7b21fef01f2 100644
--- a/test/quantization/pt2e/test_x86inductor_quantizer.py
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -3,7 +3,7 @@ import copy
 import torch
 import torch._dynamo as torchdynamo
 import torch.nn as nn
-from torch.ao.quantization.pt2e.quantizer import (
+from torch.ao.quantization.quantizer import (
     X86InductorQuantizer,
 )
 from torch.ao.quantization.quantize_pt2e import (
@@ -19,7 +19,7 @@ from torch.testing._internal.common_quantization import (
 from torch.testing._internal.common_quantized import override_quantized_engine
 from enum import Enum
 import itertools
-import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

 from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle
diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py
index 5e78fce4987..dde56ffeb2a 100644
--- a/torch/_dynamo/skipfiles.py
+++ b/torch/_dynamo/skipfiles.py
@@ -140,7 +140,7 @@ FILENAME_ALLOWLIST |= {torch.utils._foreach_utils.__file__}
 # `torch.ao.quantization._pt2e`, which interferes with memory profiling
 FILENAME_ALLOWLIST |= {
     _module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
-    _module_dir(torch) + "ao/quantization/pt2e/quantizer/xnnpack_quantizer.py",
+    _module_dir(torch) + "ao/quantization/quantizer/xnnpack_quantizer.py",
     _module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
     _module_dir(torch) + "ao/quantization/pt2e/utils.py",
 }
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 803a52c4b57..6441e03f2a6 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -106,7 +106,7 @@ from .custom_config import (
     PrepareCustomConfig,
     StandaloneModuleConfigEntry,
 )
-from torch.ao.quantization.pt2e.quantizer import (
+from torch.ao.quantization.quantizer import (
     EdgeOrNode,
     QuantizationSpec,
     FixedQParamsQuantizationSpec,
diff --git a/torch/ao/quantization/pt2e/_propagate_annotation.py b/torch/ao/quantization/pt2e/_propagate_annotation.py
index 2dc8bc4f10d..1b2fe5eda7d 100644
--- a/torch/ao/quantization/pt2e/_propagate_annotation.py
+++ b/torch/ao/quantization/pt2e/_propagate_annotation.py
@@ -1,7 +1,7 @@
 from typing import Callable

 import torch
-from torch.ao.quantization.pt2e.quantizer import (
+from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
 )
diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py
index 13f73acca35..82738eb1597 100644
--- a/torch/ao/quantization/pt2e/prepare.py
+++ b/torch/ao/quantization/pt2e/prepare.py
@@ -19,7 +19,7 @@ from torch.ao.quantization import QConfigMapping
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
 from typing import Dict, Tuple, Union, Any
-from torch.ao.quantization.pt2e.quantizer import (
+from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     EdgeOrNode,
 )
diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py
index c4665bb24a2..837e19b1602 100644
--- a/torch/ao/quantization/pt2e/qat_utils.py
+++ b/torch/ao/quantization/pt2e/qat_utils.py
@@ -8,7 +8,7 @@ from torch.fx import Graph, GraphModule, Node
 from torch.fx.subgraph_rewriter import replace_pattern_with_filters
 import torch.nn.functional as F
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
-from .quantizer import (
+from torch.ao.quantization.quantizer import (
     DerivedQuantizationSpec,
     EdgeOrNode,
     SharedQuantizationSpec,
diff --git a/torch/ao/quantization/quantize_pt2e.py b/torch/ao/quantization/quantize_pt2e.py
index 7ee7ea7e58e..985181dda44 100644
--- a/torch/ao/quantization/quantize_pt2e.py
+++ b/torch/ao/quantization/quantize_pt2e.py
@@ -15,8 +15,7 @@ from .pt2e.representation import reference_representation_rewrite
 from .fx.prepare import prepare as fx_prepare
 from .quantize_fx import _convert_to_reference_decomposed_fx
 from torch.ao.quantization import QConfigMapping
-# TODO: move quantizer to torch.ao.quantization
-from torch.ao.quantization.pt2e.quantizer import (  # noqa: F401
+from torch.ao.quantization.quantizer import (  # noqa: F401
     OperatorConfig,
     OperatorPatternType,
     QuantizationConfig,
@@ -31,13 +30,13 @@ from torch.ao.quantization.pt2e.quantizer import (  # noqa: F401
     EmbeddingQuantizer,
     ComposableQuantizer,
 )
-from torch.ao.quantization.pt2e.quantizer.utils import (  # noqa: F401
+from torch.ao.quantization.quantizer.utils import (  # noqa: F401
     get_bias_qspec,
     get_input_act_qspec,
     get_output_act_qspec,
     get_weight_qspec,
 )
-from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import (  # noqa: F401
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (  # noqa: F401
     get_symmetric_quantization_config,
 )
 from torch.ao.quantization.backend_config import BackendConfig
diff --git a/torch/ao/quantization/pt2e/quantizer/__init__.py b/torch/ao/quantization/quantizer/__init__.py
similarity index 99%
rename from torch/ao/quantization/pt2e/quantizer/__init__.py
rename to torch/ao/quantization/quantizer/__init__.py
index 9a8f140f7b4..72ff406f829 100644
--- a/torch/ao/quantization/pt2e/quantizer/__init__.py
+++ b/torch/ao/quantization/quantizer/__init__.py
@@ -1,4 +1,5 @@
-from .xnnpack_quantizer import XNNPACKQuantizer
+from .composable_quantizer import ComposableQuantizer
+from .embedding_quantizer import EmbeddingQuantizer
 from .quantizer import (
     DerivedQuantizationSpec,
     EdgeOrNode,
@@ -6,16 +7,14 @@ from .quantizer import (
     OperatorConfig,
     OperatorPatternType,
     QuantizationAnnotation,
+    QuantizationConfig,
     QuantizationSpec,
     QuantizationSpecBase,
     Quantizer,
     SharedQuantizationSpec,
-    QuantizationConfig,
 )
 from .x86_inductor_quantizer import X86InductorQuantizer
-
-from .composable_quantizer import ComposableQuantizer
-from .embedding_quantizer import EmbeddingQuantizer
+from .xnnpack_quantizer import XNNPACKQuantizer

 __all__ = [
     "ComposableQuantizer",
diff --git a/torch/ao/quantization/pt2e/quantizer/composable_quantizer.py b/torch/ao/quantization/quantizer/composable_quantizer.py
similarity index 100%
rename from torch/ao/quantization/pt2e/quantizer/composable_quantizer.py
rename to torch/ao/quantization/quantizer/composable_quantizer.py
diff --git a/torch/ao/quantization/pt2e/quantizer/embedding_quantizer.py b/torch/ao/quantization/quantizer/embedding_quantizer.py
similarity index 98%
rename from torch/ao/quantization/pt2e/quantizer/embedding_quantizer.py
rename to torch/ao/quantization/quantizer/embedding_quantizer.py
index 1aebcf2235c..35d32a7a265 100644
--- a/torch/ao/quantization/pt2e/quantizer/embedding_quantizer.py
+++ b/torch/ao/quantization/quantizer/embedding_quantizer.py
@@ -5,7 +5,8 @@ from typing import List, Set

 import torch
 import torch.nn.functional as F
-from torch.ao.quantization.pt2e.quantizer.quantizer import (
+from torch.ao.quantization.observer import PerChannelMinMaxObserver
+from torch.ao.quantization.quantizer.quantizer import (
     OperatorConfig,
     OperatorPatternType,
     QuantizationAnnotation,
@@ -13,13 +14,13 @@ from torch.ao.quantization.pt2e.quantizer.quantizer import (
     QuantizationSpec,
     Quantizer,
 )
-from torch.ao.quantization.observer import PerChannelMinMaxObserver

 __all__ = [
     "get_embedding_operators_config",
     "EmbeddingQuantizer",
 ]

+
 def get_embedding_operators_config() -> OperatorConfig:
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.uint8,
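The new torch/ao/quantization/quantizer/__init__.py above re-exports every quantizer and spec class from one place. A hedged sketch of composing the two quantizers touched in these hunks, assuming the usual set_global helper on XNNPACKQuantizer; everything beyond the import paths is illustrative:

```python
from torch.ao.quantization.quantizer import (
    ComposableQuantizer,
    EmbeddingQuantizer,
    XNNPACKQuantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)

# EmbeddingQuantizer annotates embedding weights; XNNPACKQuantizer covers the rest.
# ComposableQuantizer runs each child quantizer's annotate() in list order.
quantizer = ComposableQuantizer(
    [
        EmbeddingQuantizer(),
        XNNPACKQuantizer().set_global(get_symmetric_quantization_config()),
    ]
)
```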
diff --git a/torch/ao/quantization/pt2e/quantizer/quantizer.py b/torch/ao/quantization/quantizer/quantizer.py
similarity index 91%
rename from torch/ao/quantization/pt2e/quantizer/quantizer.py
rename to torch/ao/quantization/quantizer/quantizer.py
index bd7d0563455..4d8f9f3f9da 100644
--- a/torch/ao/quantization/pt2e/quantizer/quantizer.py
+++ b/torch/ao/quantization/quantizer/quantizer.py
@@ -1,12 +1,12 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from torch.fx import Node
-from typing import Callable, List, NamedTuple, Optional, Dict, Union, Tuple
-from torch.ao.quantization import ObserverOrFakeQuantize
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
-from torch import Tensor
+from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

 import torch
+from torch import Tensor
+from torch.ao.quantization import ObserverOrFakeQuantize
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+from torch.fx import Node

 __all__ = [
     "Quantizer",
@@ -32,17 +32,21 @@ SUPPORTED_QSCHEMES = [
     torch.per_channel_affine_float_qparams,
 ]

+
 class QuantizationSpecBase(ABC):
-    """ Base class for different types of quantization specs that allows users to
+    """Base class for different types of quantization specs that allows users to
     specify how to quantize a Tensor (input/output of a Node) in the model
     """
+
     pass

+
 @dataclass(eq=True, frozen=True)
 class QuantizationSpec(QuantizationSpecBase):
-    """ Quantization spec for common operators that allows user to specify how to
+    """Quantization spec for common operators that allows user to specify how to
     quantize a Tensor, this includes dtype, quant_min, quant_max etc.
     """
+
     dtype: torch.dtype
     # observer or fake_quantize constructor such as
     # MinMaxObserver, PerChannelHistogramObserver etc.
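For reference, a QuantizationSpec as defined above is a frozen dataclass. The sketch below builds one with the same uint8 activation settings that appear later in this diff (in get_default_x86_inductor_quantization_config), so no new API is assumed:

```python
import torch
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# Field values mirror the activation spec constructed later in this diff.
act_qspec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)
```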
@@ -79,6 +83,7 @@ class QuantizationSpec(QuantizationSpecBase):
         if self.ch_axis is not None and self.ch_axis < 0:
             raise ValueError("Ch_axis is < 0.")

+
 @dataclass(eq=True, frozen=True)
 class FixedQParamsQuantizationSpec(QuantizationSpecBase):
     dtype: torch.dtype
@@ -88,6 +93,7 @@ class FixedQParamsQuantizationSpec(QuantizationSpecBase):
     quant_max: Optional[int] = None
     qscheme: Optional[torch.qscheme] = None

+
 """
 The way we refer to other points of quantization in the graph will be either
 an input edge or an output value
@@ -95,19 +101,22 @@ input edge is the connection between input node and the node consuming the input
 output value is an fx Node
 """
 EdgeOrNode = Union[Tuple[Node, Node], Node]
-EdgeOrNode.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"
+EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer"
+

 @dataclass(eq=True, frozen=True)
 class SharedQuantizationSpec(QuantizationSpecBase):
     """
     Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
     """
+
     edge_or_node: EdgeOrNode

+
 @dataclass(eq=True, frozen=True)
 class DerivedQuantizationSpec(QuantizationSpecBase):
-    """ quantization spec for the Tensors whose quantization parameters are derived from other Tensors
-    """
+    """Quantization spec for the Tensors whose quantization parameters are derived from other Tensors"""
+
     derived_from: List[EdgeOrNode]
     derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]]
     dtype: torch.dtype
@@ -115,6 +124,7 @@ class DerivedQuantizationSpec(QuantizationSpecBase):
     quant_max: Optional[int] = None
     qscheme: Optional[torch.qscheme] = None

+
 # In the absence of better name, just winging it with QuantizationConfig
 @dataclass(eq=True, frozen=True)
 class QuantizationConfig:
@@ -125,8 +135,10 @@ class QuantizationConfig:
     # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
     is_qat: bool = False

+
 OperatorPatternType = List[Callable]
-OperatorPatternType.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"
+OperatorPatternType.__module__ = "torch.ao.quantization.quantizer.quantizer"
+

 class OperatorConfig(NamedTuple):
     # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
@@ -140,9 +152,10 @@ class OperatorConfig(NamedTuple):
     config: QuantizationConfig
     operators: List[OperatorPatternType]

+
 @dataclass
 class QuantizationAnnotation:
-    """ How are input arguemnt or output should be quantized,
+    """How are input arguemnt or output should be quantized,
     expressed as QuantizationSpec, this corresponds to how a Tensor in
     the operator Graph is observed (PTQ) or fake quantized (QAT)
     """
@@ -157,8 +170,8 @@ class QuantizationAnnotation:
     # whether the node is annotated or not
     _annotated: bool = False

-class Quantizer(ABC):

+class Quantizer(ABC):
     # annotate nodes in the graph with observer or fake quant constructors
     # to convey the desired way of quantization
     @abstractmethod
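The Quantizer ABC plus the spec and annotation classes above are the whole extension surface of this module. Below is a minimal sketch of a custom quantizer; the node selection and spec values are chosen purely for illustration, and depending on the exact version further abstract methods (for example get_supported_operators) may also need overriding:

```python
import torch
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
)

_INT8_QSPEC = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=HistogramObserver,
)


class AddOnlyQuantizer(Quantizer):
    """Illustrative sketch: observe the inputs and output of every aten.add."""

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for node in model.graph.nodes:
            if node.op == "call_function" and node.target is torch.ops.aten.add.Tensor:
                node.meta["quantization_annotation"] = QuantizationAnnotation(
                    # Only FX Node arguments can carry an input spec.
                    input_qspec_map={
                        a: _INT8_QSPEC
                        for a in node.args
                        if isinstance(a, torch.fx.Node)
                    },
                    output_qspec=_INT8_QSPEC,
                    _annotated=True,
                )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        # No extra graph checks in this sketch.
        pass
```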
diff --git a/torch/ao/quantization/pt2e/quantizer/utils.py b/torch/ao/quantization/quantizer/utils.py
similarity index 92%
rename from torch/ao/quantization/pt2e/quantizer/utils.py
rename to torch/ao/quantization/quantizer/utils.py
index aada4c22c84..316741ab302 100644
--- a/torch/ao/quantization/pt2e/quantizer/utils.py
+++ b/torch/ao/quantization/quantizer/utils.py
@@ -1,7 +1,7 @@
 from typing import List

 import torch
-from torch.ao.quantization.pt2e.quantizer.quantizer import (
+from torch.ao.quantization.quantizer.quantizer import (
     QuantizationAnnotation,
     QuantizationConfig,
     QuantizationSpec,
@@ -15,6 +15,7 @@ __all__ = [
     "get_bias_qspec",
 ]

+
 def get_input_act_qspec(quantization_config: QuantizationConfig):
     if quantization_config is None:
         return None
@@ -91,11 +92,11 @@ def _annotate_output_qspec(node: Node, qspec):


 def _is_sym_size_node(node: Node):
     return (
-        node.op == "call_function" and
-        node.target == torch.ops.aten.sym_size.default or
-        node.target == torch.ops.aten.sym_numel.default or
-        node.target == torch.ops.aten.sym_numel or
-        node.target == torch.ops.aten.sym_size
+        node.op == "call_function"
+        and node.target == torch.ops.aten.sym_size.default
+        or node.target == torch.ops.aten.sym_numel.default
+        or node.target == torch.ops.aten.sym_numel
+        or node.target == torch.ops.aten.sym_size
     )
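The helpers in utils.py above unpack a QuantizationConfig into the individual specs that annotation code consumes. A short sketch, reusing the symmetric XNNPACK config that this diff re-exports; the is_per_channel flag and the comments are illustrative:

```python
from torch.ao.quantization.quantizer.utils import (
    get_bias_qspec,
    get_input_act_qspec,
    get_output_act_qspec,
    get_weight_qspec,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)

config = get_symmetric_quantization_config(is_per_channel=True)
input_act_qspec = get_input_act_qspec(config)    # spec for operator input activations
output_act_qspec = get_output_act_qspec(config)  # spec for operator output activations
weight_qspec = get_weight_qspec(config)          # weight spec (per-channel for this config)
bias_qspec = get_bias_qspec(config)              # bias spec from the same config
```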
diff --git a/torch/ao/quantization/pt2e/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
similarity index 86%
rename from torch/ao/quantization/pt2e/quantizer/x86_inductor_quantizer.py
rename to torch/ao/quantization/quantizer/x86_inductor_quantizer.py
index 36f7785b162..cdce4923ad6 100644
--- a/torch/ao/quantization/pt2e/quantizer/x86_inductor_quantizer.py
+++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
@@ -1,45 +1,45 @@
-import torch
-import torch.nn.functional as F
 import copy
 import functools
 import itertools
 import operator
-from .quantizer import (
-    OperatorConfig,
-    OperatorPatternType,
-    QuantizationConfig,
-    QuantizationSpec,
-    Quantizer,
-    QuantizationAnnotation,
+from typing import Any, Dict, List, Optional, Set
+
+import torch
+import torch.nn.functional as F
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    PerChannelMinMaxObserver,
+    PlaceholderObserver,
 )
 from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
-from torch.ao.quantization.pt2e.quantizer.utils import (
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+from torch.ao.quantization.quantizer.utils import (
+    get_bias_qspec,
     get_input_act_qspec,
     get_output_act_qspec,
     get_weight_qspec,
-    get_bias_qspec,
 )
-from .xnnpack_quantizer import (
-    _is_annotated,
-)
-from torch.ao.quantization.observer import (
-    HistogramObserver,
-    PlaceholderObserver,
-    PerChannelMinMaxObserver,
-)
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
-from typing import List, Dict, Optional, Set, Any
 from torch.fx import Node
 from torch.fx.passes.utils.source_matcher_utils import (
     get_source_partitions,
     SourcePartition,
 )
+from .quantizer import (
+    OperatorConfig,
+    OperatorPatternType,
+    QuantizationAnnotation,
+    QuantizationConfig,
+    QuantizationSpec,
+    Quantizer,
+)
+from .xnnpack_quantizer import _is_annotated

 __all__ = [
     "X86InductorQuantizer",
     "get_default_x86_inductor_quantization_config",
 ]

+
 def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     # TODO: Add more supported operators here.
     supported_operators: Dict[str, List[OperatorPatternType]] = {
@@ -71,7 +71,9 @@ def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:

 def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
     supported_config_and_operators: List[OperatorConfig] = []
-    for quantization_config in [get_default_x86_inductor_quantization_config(), ]:
+    for quantization_config in [
+        get_default_x86_inductor_quantization_config(),
+    ]:
         ops = _supported_quantized_operators()
         for pattern_list in ops.values():
             supported_config_and_operators.append(
@@ -82,8 +84,9 @@ def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:

 @functools.lru_cache
 def get_default_x86_inductor_quantization_config():
-    act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \
+    act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
         HistogramObserver
+    )

     # Copy from x86 default qconfig from torch/ao/quantization/qconfig.py
     act_quantization_spec = QuantizationSpec(
@@ -92,10 +95,14 @@ def get_default_x86_inductor_quantization_config():
         quant_max=255,  # reduce_range=False
         qscheme=torch.per_tensor_affine,
         is_dynamic=False,
-        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            eps=2**-12
+        ),
     )
-    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver
+    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        PerChannelMinMaxObserver
+    )
     extra_args: Dict[str, Any] = {"eps": 2**-12}
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
@@ -104,18 +111,21 @@ def get_default_x86_inductor_quantization_config():
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,  # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv
         is_dynamic=False,
-        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args),
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
+    )
+    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        PlaceholderObserver
     )
-    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver
     bias_quantization_spec = QuantizationSpec(
-        dtype=torch.float,
-        observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
+        dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
     )
     quantization_config = QuantizationConfig(
         act_quantization_spec,
         act_quantization_spec,
         weight_quantization_spec,
-        bias_quantization_spec
+        bias_quantization_spec,
     )
     return quantization_config
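Putting the pieces of this file together, the intended calibration flow looks roughly like the sketch below. The model and the PT2 export step are placeholders (the capture API is outside this diff), and set_global is the configuration helper exercised by the tests updated above:

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import X86InductorQuantizer

exported_model = ...  # ATen-level GraphModule from the PT2 export/capture API of your build
example_inputs = (torch.randn(1, 3, 224, 224),)

quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

prepared = prepare_pt2e(exported_model, quantizer)  # inserts observers per the annotations
prepared(*example_inputs)                           # calibration pass
quantized = convert_pt2e(prepared)                  # rewrites to quantize/dequantize ops
```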
@@ -165,13 +175,12 @@ class X86InductorQuantizer(Quantizer):
         return self

     def _annotate_conv_node_helper(
-            self,
-            conv_node: torch.fx.Node,
-            annotate_output: bool,
-            quantization_config: QuantizationConfig,
-    ) -> None :
-        """ Helper function to annotate the conv node
-        """
+        self,
+        conv_node: torch.fx.Node,
+        annotate_output: bool,
+        quantization_config: QuantizationConfig,
+    ) -> None:
+        """Helper function to annotate the conv node"""
         input_qspec_map = {}
         input_node = conv_node.args[0]
         assert isinstance(input_node, Node)
@@ -187,20 +196,18 @@ class X86InductorQuantizer(Quantizer):
                 input_qspec_map=input_qspec_map,
                 # TODO Remove the annotate of output when oneDNN qconv support fp32 out.
                 output_qspec=get_output_act_qspec(quantization_config),
-                _annotated=True
+                _annotated=True,
             )
         else:
             conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                input_qspec_map=input_qspec_map,
-                _annotated=True
+                input_qspec_map=input_qspec_map, _annotated=True
             )

     def _get_output_nodes_of_partitions(
         self,
         partition_list: List[SourcePartition],
     ) -> List[torch.fx.Node]:
-        """ Helper function to get the output node list from partition list
-        """
+        """Helper function to get the output node list from partition list"""
         output_node_list = []
         for partition in partition_list:
             if len(partition.output_nodes) > 1:
@@ -209,7 +216,9 @@ class X86InductorQuantizer(Quantizer):
             assert isinstance(output_node, Node)
             output_node_list.append(output_node)
         if len(output_node_list) != len(partition_list):
-            raise ValueError("length of output_node_list should equal to length of partition_list")
+            raise ValueError(
+                "length of output_node_list should equal to length of partition_list"
+            )
         return output_node_list

     def _get_input_idx_for_binary_node(
@@ -217,7 +226,7 @@ class X86InductorQuantizer(Quantizer):
         conv_gemm_node: torch.fx.Node,
         binary_node: torch.fx.Node,
     ):
-        """ Helper function to check conv_gemm and extra input node index
+        """Helper function to check conv_gemm and extra input node index
         for binary node fused with conv_gemm.
         """
         conv_gemm_node_idx = None
@@ -237,8 +246,7 @@ class X86InductorQuantizer(Quantizer):
         return conv_gemm_node_idx, extra_input_node_idx

     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-        """ just handling global spec for now
-        """
+        """just handling global spec for now"""
         model = self._annotate_for_static_quantization_config(model)
         return model

@@ -268,28 +276,34 @@
             conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions(
                 [conv_partition, binary_partition, unary_partition]
             )
-            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(conv_node, binary_node)
+            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
+                conv_node, binary_node
+            )
             if (conv_node_idx is None) or (extra_input_node_idx is None):
                 continue
             if conv_node != binary_node.args[conv_node_idx]:
                 raise ValueError(f"{conv_node} doesn't match input of binary node")
             extra_input_node = binary_node.args[extra_input_node_idx]
-            if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.convolution.default
+            ):
                 # No conv node found to be fused with add
                 continue
             if _is_annotated([unary_node, binary_node, conv_node]):
                 continue
             self._annotate_conv_node_helper(conv_node, False, quantization_config)
             binary_node_input_qspec_map = {}
-            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(quantization_config)
+            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
+                quantization_config
+            )
             binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                input_qspec_map=binary_node_input_qspec_map,
-                _annotated=True
+                input_qspec_map=binary_node_input_qspec_map, _annotated=True
             )
             unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
                 # TODO Remove the annotate of output when oneDNN qconv support fp32 out.
                 output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
-                _annotated=True
+                _annotated=True,
             )

     def _annotate_conv2d_binary(
@@ -304,26 +318,33 @@
             conv_node, binary_node = self._get_output_nodes_of_partitions(
                 [conv_partition, binary_partition]
             )
-            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(conv_node, binary_node)
+            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
+                conv_node, binary_node
+            )
             if (conv_node_idx is None) or (extra_input_node_idx is None):
                 continue
             if conv_node != binary_node.args[conv_node_idx]:
                 raise ValueError(f"{conv_node} doesn't match input of binary node")
             extra_input_node = binary_node.args[extra_input_node_idx]
             assert isinstance(conv_node, Node)
-            if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.convolution.default
+            ):
                 # No conv node found to be fused with add
                 continue
             if _is_annotated([binary_node, conv_node]):
                 continue
             self._annotate_conv_node_helper(conv_node, False, quantization_config)
             binary_node_input_qspec_map = {}
-            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(quantization_config)
+            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
+                quantization_config
+            )
             binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
                 input_qspec_map=binary_node_input_qspec_map,
                 # TODO Remove the annotate of output when oneDNN qconv support fp32 out.
                 output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
-                _annotated=True
+                _annotated=True,
             )

     def _annotate_conv2d_unary(
@@ -337,7 +358,10 @@
             conv_node, unary_node = self._get_output_nodes_of_partitions(
                 [conv_partition, unary_partition]
             )
-            if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.convolution.default
+            ):
                 continue
             if _is_annotated([unary_node, conv_node]):
                 continue
@@ -345,7 +369,7 @@
             unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
                 # TODO Remove the annotate of output when oneDNN qconv support fp32 out.
                 output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
-                _annotated=True
+                _annotated=True,
             )

     def _annotate_conv2d(
diff --git a/torch/ao/quantization/pt2e/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
similarity index 99%
rename from torch/ao/quantization/pt2e/quantizer/xnnpack_quantizer.py
rename to torch/ao/quantization/quantizer/xnnpack_quantizer.py
index bda553d03b0..bd6df9aa5b1 100644
--- a/torch/ao/quantization/pt2e/quantizer/xnnpack_quantizer.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -21,8 +21,9 @@ from torch.ao.quantization.observer import (
 )

 from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor

-from torch.ao.quantization.pt2e.quantizer.utils import (
+from torch.ao.quantization.quantizer.utils import (
     _annotate_input_qspec_map,
     _annotate_output_qspec,
     _is_sym_size_node,
@@ -32,7 +33,6 @@ from torch.ao.quantization.pt2e.quantizer.utils import (
     get_output_act_qspec,
     get_weight_qspec,
 )
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor

 from torch.fx import Node
 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
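xnnpack_quantizer.py itself only has its import paths touched, so existing XNNPACK-style flows keep working under the new namespace. For completeness, a sketch mirroring the x86 example earlier; the export step is again a placeholder and the per-channel flag is an illustrative choice:

```python
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import XNNPACKQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True)
)

exported_model = ...  # ATen-level GraphModule from the PT2 export/capture API
prepared = prepare_pt2e(exported_model, quantizer)
# ...feed calibration batches through `prepared` here...
quantized = convert_pt2e(prepared)
```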