[quant][api] Move torch.ao.quantization.pt2e.quantizer to torch.ao.quantization.quantizer (#105885)

Summary: Move the quantizer to torch.ao.quantization to make it a public API, since pt2e is the folder for implementation details.
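
For downstream code the change is a pure import-path rename; a minimal migration sketch (the module and class names are the ones touched in the diff below):

    # before: implementation-internal path (removed by this PR)
    # from torch.ao.quantization.pt2e.quantizer import XNNPACKQuantizer

    # after: public path introduced by this PR
    from torch.ao.quantization.quantizer import XNNPACKQuantizer
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
    )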

Test Plan:
CIs

sanity check: "buck test //executorch/backends/xnnpack/test:test_xnnpack_quantized_models -- test_resnet18"

Differential Revision: D47727838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105885
Approved by: https://github.com/andrewor14
Jerry Zhang 2023-07-26 18:20:09 +00:00 committed by PyTorch MergeBot
parent 70b0f1b248
commit 3a77f9aaaf
17 changed files with 146 additions and 105 deletions

View File

@ -120,11 +120,15 @@ This module contains a few CustomConfig classes that's used in both eager mode a
ConvertCustomConfig
StandaloneModuleConfigEntry
torch.ao.quantization.pt2e (quantization in pytorch 2.0 export)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
torch.ao.quantization.quantizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: torch.ao.quantization.quantizer
torch.ao.quantization.pt2e (quantization in pytorch 2.0 export implementation)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: torch.ao.quantization.pt2e
.. automodule:: torch.ao.quantization.pt2e.quantizer
.. automodule:: torch.ao.quantization.pt2e.representation
torch (quantization related functions)

View File

@ -12,13 +12,13 @@ import weakref
import torch
import torch._dynamo as torchdynamo
import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch import nn
from torch._inductor import config
from torch._inductor.compile_fx import compile_fx
from torch._inductor.utils import override_lowering, run_and_get_code
from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import X86InductorQuantizer
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import (
skipIfNoDynamoSupport,

View File

@ -15,7 +15,7 @@ from torch.ao.quantization import (
ObserverOrFakeQuantize,
QConfigMapping,
)
from torch.ao.quantization.pt2e.quantizer import (
from torch.ao.quantization.quantizer import (
ComposableQuantizer,
DerivedQuantizationSpec,
EmbeddingQuantizer,
@ -27,10 +27,10 @@ from torch.ao.quantization.pt2e.quantizer import (
Quantizer,
SharedQuantizationSpec,
)
from torch.ao.quantization.pt2e.quantizer.composable_quantizer import ( # noqa: F811
from torch.ao.quantization.quantizer.composable_quantizer import ( # noqa: F811
ComposableQuantizer,
)
from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import (
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import (
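
The test imports above exercise the relocated quantizers together. A minimal sketch of how they compose under the new paths, assuming the set_global and list-based ComposableQuantizer APIs that these tests rely on:

    from torch.ao.quantization.quantizer import (
        ComposableQuantizer,
        EmbeddingQuantizer,
        XNNPACKQuantizer,
    )
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
    )

    # XNNPACKQuantizer annotates general ops; EmbeddingQuantizer covers embeddings.
    # ComposableQuantizer applies each quantizer's annotations in order.
    xnnpack_quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    composed = ComposableQuantizer([EmbeddingQuantizer(), xnnpack_quantizer])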

View File

@ -3,7 +3,7 @@ import copy
import torch
import torch._dynamo as torchdynamo
import torch.nn as nn
from torch.ao.quantization.pt2e.quantizer import (
from torch.ao.quantization.quantizer import (
X86InductorQuantizer,
)
from torch.ao.quantization.quantize_pt2e import (
@ -19,7 +19,7 @@ from torch.testing._internal.common_quantization import (
from torch.testing._internal.common_quantized import override_quantized_engine
from enum import Enum
import itertools
import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle

View File

@ -140,7 +140,7 @@ FILENAME_ALLOWLIST |= {torch.utils._foreach_utils.__file__}
# `torch.ao.quantization._pt2e`, which interferes with memory profiling
FILENAME_ALLOWLIST |= {
_module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
_module_dir(torch) + "ao/quantization/pt2e/quantizer/xnnpack_quantizer.py",
_module_dir(torch) + "ao/quantization/quantizer/xnnpack_quantizer.py",
_module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
_module_dir(torch) + "ao/quantization/pt2e/utils.py",
}

View File

@ -106,7 +106,7 @@ from .custom_config import (
PrepareCustomConfig,
StandaloneModuleConfigEntry,
)
from torch.ao.quantization.pt2e.quantizer import (
from torch.ao.quantization.quantizer import (
EdgeOrNode,
QuantizationSpec,
FixedQParamsQuantizationSpec,

View File

@ -1,7 +1,7 @@
from typing import Callable
import torch
from torch.ao.quantization.pt2e.quantizer import (
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
SharedQuantizationSpec,
)

View File

@ -19,7 +19,7 @@ from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from typing import Dict, Tuple, Union, Any
from torch.ao.quantization.pt2e.quantizer import (
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
EdgeOrNode,
)

View File

@ -8,7 +8,7 @@ from torch.fx import Graph, GraphModule, Node
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from .quantizer import (
from torch.ao.quantization.quantizer import (
DerivedQuantizationSpec,
EdgeOrNode,
SharedQuantizationSpec,

View File

@ -15,8 +15,7 @@ from .pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization import QConfigMapping
# TODO: move quantizer to torch.ao.quantization
from torch.ao.quantization.pt2e.quantizer import ( # noqa: F401
from torch.ao.quantization.quantizer import ( # noqa: F401
OperatorConfig,
OperatorPatternType,
QuantizationConfig,
@ -31,13 +30,13 @@ from torch.ao.quantization.pt2e.quantizer import ( # noqa: F401
EmbeddingQuantizer,
ComposableQuantizer,
)
from torch.ao.quantization.pt2e.quantizer.utils import ( # noqa: F401
from torch.ao.quantization.quantizer.utils import ( # noqa: F401
get_bias_qspec,
get_input_act_qspec,
get_output_act_qspec,
get_weight_qspec,
)
from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import ( # noqa: F401
from torch.ao.quantization.quantizer.xnnpack_quantizer import ( # noqa: F401
get_symmetric_quantization_config,
)
from torch.ao.quantization.backend_config import BackendConfig

View File

@ -1,4 +1,5 @@
from .xnnpack_quantizer import XNNPACKQuantizer
from .composable_quantizer import ComposableQuantizer
from .embedding_quantizer import EmbeddingQuantizer
from .quantizer import (
DerivedQuantizationSpec,
EdgeOrNode,
@ -6,16 +7,14 @@ from .quantizer import (
OperatorConfig,
OperatorPatternType,
QuantizationAnnotation,
QuantizationConfig,
QuantizationSpec,
QuantizationSpecBase,
Quantizer,
SharedQuantizationSpec,
QuantizationConfig,
)
from .x86_inductor_quantizer import X86InductorQuantizer
from .composable_quantizer import ComposableQuantizer
from .embedding_quantizer import EmbeddingQuantizer
from .xnnpack_quantizer import XNNPACKQuantizer
__all__ = [
"ComposableQuantizer",

View File

@ -5,7 +5,8 @@ from typing import List, Set
import torch
import torch.nn.functional as F
from torch.ao.quantization.pt2e.quantizer.quantizer import (
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import (
OperatorConfig,
OperatorPatternType,
QuantizationAnnotation,
@ -13,13 +14,13 @@ from torch.ao.quantization.pt2e.quantizer.quantizer import (
QuantizationSpec,
Quantizer,
)
from torch.ao.quantization.observer import PerChannelMinMaxObserver
__all__ = [
"get_embedding_operators_config",
"EmbeddingQuantizer",
]
def get_embedding_operators_config() -> OperatorConfig:
weight_quantization_spec = QuantizationSpec(
dtype=torch.uint8,

View File

@ -1,12 +1,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from torch.fx import Node
from typing import Callable, List, NamedTuple, Optional, Dict, Union, Tuple
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch import Tensor
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node
__all__ = [
"Quantizer",
@ -32,17 +32,21 @@ SUPPORTED_QSCHEMES = [
torch.per_channel_affine_float_qparams,
]
class QuantizationSpecBase(ABC):
""" Base class for different types of quantization specs that allows users to
"""Base class for different types of quantization specs that allows users to
specify how to quantize a Tensor (input/output of a Node) in the model
"""
pass
@dataclass(eq=True, frozen=True)
class QuantizationSpec(QuantizationSpecBase):
""" Quantization spec for common operators that allows user to specify how to
"""Quantization spec for common operators that allows user to specify how to
quantize a Tensor, this includes dtype, quant_min, quant_max etc.
"""
dtype: torch.dtype
# observer or fake_quantize constructor such as
# MinMaxObserver, PerChannelHistogramObserver etc.
@ -79,6 +83,7 @@ class QuantizationSpec(QuantizationSpecBase):
if self.ch_axis is not None and self.ch_axis < 0:
raise ValueError("Ch_axis is < 0.")
@dataclass(eq=True, frozen=True)
class FixedQParamsQuantizationSpec(QuantizationSpecBase):
dtype: torch.dtype
@ -88,6 +93,7 @@ class FixedQParamsQuantizationSpec(QuantizationSpecBase):
quant_max: Optional[int] = None
qscheme: Optional[torch.qscheme] = None
"""
The way we refer to other points of quantization in the graph will be either
an input edge or an output value
@ -95,19 +101,22 @@ input edge is the connection between input node and the node consuming the input
output value is an fx Node
"""
EdgeOrNode = Union[Tuple[Node, Node], Node]
EdgeOrNode.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"
EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer"
@dataclass(eq=True, frozen=True)
class SharedQuantizationSpec(QuantizationSpecBase):
"""
Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
"""
edge_or_node: EdgeOrNode
@dataclass(eq=True, frozen=True)
class DerivedQuantizationSpec(QuantizationSpecBase):
""" quantization spec for the Tensors whose quantization parameters are derived from other Tensors
"""
"""Quantization spec for the Tensors whose quantization parameters are derived from other Tensors"""
derived_from: List[EdgeOrNode]
derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]]
dtype: torch.dtype
@ -115,6 +124,7 @@ class DerivedQuantizationSpec(QuantizationSpecBase):
quant_max: Optional[int] = None
qscheme: Optional[torch.qscheme] = None
# In the absence of better name, just winging it with QuantizationConfig
@dataclass(eq=True, frozen=True)
class QuantizationConfig:
@ -125,8 +135,10 @@ class QuantizationConfig:
# TODO: remove, since we can use observer_or_fake_quant_ctr to express this
is_qat: bool = False
OperatorPatternType = List[Callable]
OperatorPatternType.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"
OperatorPatternType.__module__ = "torch.ao.quantization.quantizer.quantizer"
class OperatorConfig(NamedTuple):
# fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
@ -140,9 +152,10 @@ class OperatorConfig(NamedTuple):
config: QuantizationConfig
operators: List[OperatorPatternType]
@dataclass
class QuantizationAnnotation:
""" How are input arguemnt or output should be quantized,
"""How are input arguemnt or output should be quantized,
expressed as QuantizationSpec, this corresponds to how a Tensor in the
operator Graph is observed (PTQ) or fake quantized (QAT)
"""
@ -157,8 +170,8 @@ class QuantizationAnnotation:
# whether the node is annotated or not
_annotated: bool = False
class Quantizer(ABC):
class Quantizer(ABC):
# annotate nodes in the graph with observer or fake quant constructors
# to convey the desired way of quantization
@abstractmethod
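
Taken together, these dataclasses are what a quantizer's annotate() attaches to graph nodes. A minimal sketch of constructing specs from this file (the optional field values are borrowed from the x86 config further down; `conv_node` is a hypothetical fx Node):

    import torch
    from torch.ao.quantization.observer import HistogramObserver
    from torch.ao.quantization.quantizer.quantizer import (
        QuantizationSpec,
        SharedQuantizationSpec,
    )

    # Per-tensor affine uint8 activation spec, observed with a HistogramObserver.
    act_qspec = QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
    )

    # Share quantization parameters with an existing output value (an fx Node);
    # an (input_node, consumer_node) tuple would denote an input edge instead.
    # shared_qspec = SharedQuantizationSpec(conv_node)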

View File

@ -1,7 +1,7 @@
from typing import List
import torch
from torch.ao.quantization.pt2e.quantizer.quantizer import (
from torch.ao.quantization.quantizer.quantizer import (
QuantizationAnnotation,
QuantizationConfig,
QuantizationSpec,
@ -15,6 +15,7 @@ __all__ = [
"get_bias_qspec",
]
def get_input_act_qspec(quantization_config: QuantizationConfig):
if quantization_config is None:
return None
@ -91,11 +92,11 @@ def _annotate_output_qspec(node: Node, qspec):
def _is_sym_size_node(node: Node):
return (
node.op == "call_function" and
node.target == torch.ops.aten.sym_size.default or
node.target == torch.ops.aten.sym_numel.default or
node.target == torch.ops.aten.sym_numel or
node.target == torch.ops.aten.sym_size
node.op == "call_function"
and node.target == torch.ops.aten.sym_size.default
or node.target == torch.ops.aten.sym_numel.default
or node.target == torch.ops.aten.sym_numel
or node.target == torch.ops.aten.sym_size
)

View File

@ -1,45 +1,45 @@
import torch
import torch.nn.functional as F
import copy
import functools
import itertools
import operator
from .quantizer import (
OperatorConfig,
OperatorPatternType,
QuantizationConfig,
QuantizationSpec,
Quantizer,
QuantizationAnnotation,
from typing import Any, Dict, List, Optional, Set
import torch
import torch.nn.functional as F
from torch.ao.quantization.observer import (
HistogramObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.pt2e.quantizer.utils import (
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer.utils import (
get_bias_qspec,
get_input_act_qspec,
get_output_act_qspec,
get_weight_qspec,
get_bias_qspec,
)
from .xnnpack_quantizer import (
_is_annotated,
)
from torch.ao.quantization.observer import (
HistogramObserver,
PlaceholderObserver,
PerChannelMinMaxObserver,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from typing import List, Dict, Optional, Set, Any
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import (
get_source_partitions,
SourcePartition,
)
from .quantizer import (
OperatorConfig,
OperatorPatternType,
QuantizationAnnotation,
QuantizationConfig,
QuantizationSpec,
Quantizer,
)
from .xnnpack_quantizer import _is_annotated
__all__ = [
"X86InductorQuantizer",
"get_default_x86_inductor_quantization_config",
]
def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
# TODO: Add more supported operators here.
supported_operators: Dict[str, List[OperatorPatternType]] = {
@ -71,7 +71,9 @@ def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
supported_config_and_operators: List[OperatorConfig] = []
for quantization_config in [get_default_x86_inductor_quantization_config(), ]:
for quantization_config in [
get_default_x86_inductor_quantization_config(),
]:
ops = _supported_quantized_operators()
for pattern_list in ops.values():
supported_config_and_operators.append(
@ -82,8 +84,9 @@ def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
@functools.lru_cache
def get_default_x86_inductor_quantization_config():
act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \
act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
HistogramObserver
)
# Copy from x86 default qconfig from torch/ao/quantization/qconfig.py
act_quantization_spec = QuantizationSpec(
@ -92,10 +95,14 @@ def get_default_x86_inductor_quantization_config():
quant_max=255, # reduce_range=False
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
eps=2**-12
),
)
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
PerChannelMinMaxObserver
)
extra_args: Dict[str, Any] = {"eps": 2**-12}
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
@ -104,18 +111,21 @@ def get_default_x86_inductor_quantization_config():
qscheme=torch.per_channel_symmetric,
ch_axis=0, # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv
is_dynamic=False,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args),
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
**extra_args
),
)
bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
PlaceholderObserver
)
bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver
bias_quantization_spec = QuantizationSpec(
dtype=torch.float,
observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
)
quantization_config = QuantizationConfig(
act_quantization_spec,
act_quantization_spec,
weight_quantization_spec,
bias_quantization_spec
bias_quantization_spec,
)
return quantization_config
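
The default config above pairs with X86InductorQuantizer at its new public path; a minimal usage sketch, assuming a PT2E-exported GraphModule (a hypothetical `exported_model`) as in the inductor tests earlier in this diff:

    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    quantizer = xiq.X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

    # prepared = prepare_pt2e(exported_model, quantizer)  # insert observers
    # ... run representative calibration batches through `prepared` ...
    # converted = convert_pt2e(prepared)                  # lower to quantized ops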
@ -165,13 +175,12 @@ class X86InductorQuantizer(Quantizer):
return self
def _annotate_conv_node_helper(
self,
conv_node: torch.fx.Node,
annotate_output: bool,
quantization_config: QuantizationConfig,
) -> None :
""" Helper function to annotate the conv node
"""
self,
conv_node: torch.fx.Node,
annotate_output: bool,
quantization_config: QuantizationConfig,
) -> None:
"""Helper function to annotate the conv node"""
input_qspec_map = {}
input_node = conv_node.args[0]
assert isinstance(input_node, Node)
@ -187,20 +196,18 @@ class X86InductorQuantizer(Quantizer):
input_qspec_map=input_qspec_map,
# TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
output_qspec=get_output_act_qspec(quantization_config),
_annotated=True
_annotated=True,
)
else:
conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True
input_qspec_map=input_qspec_map, _annotated=True
)
def _get_output_nodes_of_partitions(
self,
partition_list: List[SourcePartition],
) -> List[torch.fx.Node]:
""" Helper function to get the output node list from partition list
"""
"""Helper function to get the output node list from partition list"""
output_node_list = []
for partition in partition_list:
if len(partition.output_nodes) > 1:
@ -209,7 +216,9 @@ class X86InductorQuantizer(Quantizer):
assert isinstance(output_node, Node)
output_node_list.append(output_node)
if len(output_node_list) != len(partition_list):
raise ValueError("length of output_node_list should equal to length of partition_list")
raise ValueError(
"length of output_node_list should equal to length of partition_list"
)
return output_node_list
def _get_input_idx_for_binary_node(
@ -217,7 +226,7 @@ class X86InductorQuantizer(Quantizer):
conv_gemm_node: torch.fx.Node,
binary_node: torch.fx.Node,
):
""" Helper function to check conv_gemm and extra input node index
"""Helper function to check conv_gemm and extra input node index
for binary node fused with conv_gemm.
"""
conv_gemm_node_idx = None
@ -237,8 +246,7 @@ class X86InductorQuantizer(Quantizer):
return conv_gemm_node_idx, extra_input_node_idx
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
""" just handling global spec for now
"""
"""just handling global spec for now"""
model = self._annotate_for_static_quantization_config(model)
return model
@ -268,28 +276,34 @@ class X86InductorQuantizer(Quantizer):
conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions(
[conv_partition, binary_partition, unary_partition]
)
conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(conv_node, binary_node)
conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
conv_node, binary_node
)
if (conv_node_idx is None) or (extra_input_node_idx is None):
continue
if conv_node != binary_node.args[conv_node_idx]:
raise ValueError(f"{conv_node} doesn't match input of binary node")
extra_input_node = binary_node.args[extra_input_node_idx]
if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.convolution.default
):
# No conv node found to be fused with add
continue
if _is_annotated([unary_node, binary_node, conv_node]):
continue
self._annotate_conv_node_helper(conv_node, False, quantization_config)
binary_node_input_qspec_map = {}
binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(quantization_config)
binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
quantization_config
)
binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=binary_node_input_qspec_map,
_annotated=True
input_qspec_map=binary_node_input_qspec_map, _annotated=True
)
unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
# TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True
_annotated=True,
)
def _annotate_conv2d_binary(
@ -304,26 +318,33 @@ class X86InductorQuantizer(Quantizer):
conv_node, binary_node = self._get_output_nodes_of_partitions(
[conv_partition, binary_partition]
)
conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(conv_node, binary_node)
conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
conv_node, binary_node
)
if (conv_node_idx is None) or (extra_input_node_idx is None):
continue
if conv_node != binary_node.args[conv_node_idx]:
raise ValueError(f"{conv_node} doesn't match input of binary node")
extra_input_node = binary_node.args[extra_input_node_idx]
assert isinstance(conv_node, Node)
if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.convolution.default
):
# No conv node found to be fused with add
continue
if _is_annotated([binary_node, conv_node]):
continue
self._annotate_conv_node_helper(conv_node, False, quantization_config)
binary_node_input_qspec_map = {}
binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(quantization_config)
binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
quantization_config
)
binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=binary_node_input_qspec_map,
# TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True
_annotated=True,
)
def _annotate_conv2d_unary(
@ -337,7 +358,10 @@ class X86InductorQuantizer(Quantizer):
conv_node, unary_node = self._get_output_nodes_of_partitions(
[conv_partition, unary_partition]
)
if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.convolution.default
):
continue
if _is_annotated([unary_node, conv_node]):
continue
@ -345,7 +369,7 @@ class X86InductorQuantizer(Quantizer):
unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
# TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True
_annotated=True,
)
def _annotate_conv2d(

View File

@ -21,8 +21,9 @@ from torch.ao.quantization.observer import (
)
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.pt2e.quantizer.utils import (
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
_is_sym_size_node,
@ -32,7 +33,6 @@ from torch.ao.quantization.pt2e.quantizer.utils import (
get_output_act_qspec,
get_weight_qspec,
)
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions