[quant][api] Move torch.ao.quantization.pt2e.quantizer to torch.ao.quantization.quantizer (#105885)
Summary: moving quantizer to torch.ao.quantization to make it a public api, since pt2e is a folder for implementations

Test Plan: CIs
sanity check: "buck test //executorch/backends/xnnpack/test:test_xnnpack_quantized_models -- test_resnet18"

Differential Revision: D47727838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105885
Approved by: https://github.com/andrewor14
This commit is contained in:
parent 70b0f1b248
commit 3a77f9aaaf
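The change is purely an import-path move; no quantizer behavior changes. As a hedged sketch of what the new public path looks like in use (the example model, the capture call, and the calibration step below are illustrative assumptions, not taken from this commit):

    # Old, now-removed path:       torch.ao.quantization.pt2e.quantizer
    # New public path (this PR):   torch.ao.quantization.quantizer
    import torch
    import torch._dynamo as torchdynamo

    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer import XNNPACKQuantizer
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
    )

    # Illustrative model; any module that dynamo can export works the same way.
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 3, 32, 32),)

    # Capture an ATen graph; this was the capture entry point used by the pt2e
    # tests of this era and is an assumption here (it has since evolved).
    exported_model, _ = torchdynamo.export(model, *example_inputs, aten_graph=True)

    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    prepared = prepare_pt2e(exported_model, quantizer)
    prepared(*example_inputs)           # calibration
    quantized = convert_pt2e(prepared)  # reference quantized model

The hunks below are reconstructed from the commit's diff; removed lines are prefixed with "-", added lines with "+".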
@@ -120,11 +120,15 @@ This module contains a few CustomConfig classes that's used in both eager mode a
     ConvertCustomConfig
     StandaloneModuleConfigEntry
 
-torch.ao.quantization.pt2e (quantization in pytorch 2.0 export)
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+torch.ao.quantization.quantizer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: torch.ao.quantization.quantizer
+
+torch.ao.quantization.pt2e (quantization in pytorch 2.0 export implementation)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. automodule:: torch.ao.quantization.pt2e
-.. automodule:: torch.ao.quantization.pt2e.quantizer
 .. automodule:: torch.ao.quantization.pt2e.representation
 
 torch (quantization related functions)
@@ -12,13 +12,13 @@ import weakref
 import torch
 
 import torch._dynamo as torchdynamo
-import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
 from torch import nn
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx
 from torch._inductor.utils import override_lowering, run_and_get_code
-from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer import X86InductorQuantizer
 from torch.testing import FileCheck
 from torch.testing._internal.common_quantization import (
     skipIfNoDynamoSupport,
@@ -15,7 +15,7 @@ from torch.ao.quantization import (
     ObserverOrFakeQuantize,
     QConfigMapping,
 )
-from torch.ao.quantization.pt2e.quantizer import (
+from torch.ao.quantization.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
     EmbeddingQuantizer,
@@ -27,10 +27,10 @@ from torch.ao.quantization.pt2e.quantizer import (
     Quantizer,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization.pt2e.quantizer.composable_quantizer import ( # noqa: F811
+from torch.ao.quantization.quantizer.composable_quantizer import ( # noqa: F811
     ComposableQuantizer,
 )
-from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import (
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
 )
 from torch.ao.quantization.quantize_pt2e import (
@@ -3,7 +3,7 @@ import copy
 import torch
 import torch._dynamo as torchdynamo
 import torch.nn as nn
-from torch.ao.quantization.pt2e.quantizer import (
+from torch.ao.quantization.quantizer import (
     X86InductorQuantizer,
 )
 from torch.ao.quantization.quantize_pt2e import (
@@ -19,7 +19,7 @@ from torch.testing._internal.common_quantization import (
 from torch.testing._internal.common_quantized import override_quantized_engine
 from enum import Enum
 import itertools
-import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
 from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle
 
 
@@ -140,7 +140,7 @@ FILENAME_ALLOWLIST |= {torch.utils._foreach_utils.__file__}
 # `torch.ao.quantization._pt2e`, which interferes with memory profiling
 FILENAME_ALLOWLIST |= {
     _module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
-    _module_dir(torch) + "ao/quantization/pt2e/quantizer/xnnpack_quantizer.py",
+    _module_dir(torch) + "ao/quantization/quantizer/xnnpack_quantizer.py",
     _module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
     _module_dir(torch) + "ao/quantization/pt2e/utils.py",
 }
@@ -106,7 +106,7 @@ from .custom_config import (
     PrepareCustomConfig,
     StandaloneModuleConfigEntry,
 )
-from torch.ao.quantization.pt2e.quantizer import (
+from torch.ao.quantization.quantizer import (
     EdgeOrNode,
     QuantizationSpec,
     FixedQParamsQuantizationSpec,
@@ -1,7 +1,7 @@
 from typing import Callable
 
 import torch
-from torch.ao.quantization.pt2e.quantizer import (
+from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
 )
@@ -19,7 +19,7 @@ from torch.ao.quantization import QConfigMapping
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
 from typing import Dict, Tuple, Union, Any
-from torch.ao.quantization.pt2e.quantizer import (
+from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     EdgeOrNode,
 )
@@ -8,7 +8,7 @@ from torch.fx import Graph, GraphModule, Node
 from torch.fx.subgraph_rewriter import replace_pattern_with_filters
 import torch.nn.functional as F
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
-from .quantizer import (
+from torch.ao.quantization.quantizer import (
     DerivedQuantizationSpec,
     EdgeOrNode,
     SharedQuantizationSpec,
@@ -15,8 +15,7 @@ from .pt2e.representation import reference_representation_rewrite
 from .fx.prepare import prepare as fx_prepare
 from .quantize_fx import _convert_to_reference_decomposed_fx
 from torch.ao.quantization import QConfigMapping
-# TODO: move quantizer to torch.ao.quantization
-from torch.ao.quantization.pt2e.quantizer import ( # noqa: F401
+from torch.ao.quantization.quantizer import ( # noqa: F401
     OperatorConfig,
     OperatorPatternType,
     QuantizationConfig,
@@ -31,13 +30,13 @@ from torch.ao.quantization.pt2e.quantizer import ( # noqa: F401
     EmbeddingQuantizer,
     ComposableQuantizer,
 )
-from torch.ao.quantization.pt2e.quantizer.utils import ( # noqa: F401
+from torch.ao.quantization.quantizer.utils import ( # noqa: F401
     get_bias_qspec,
     get_input_act_qspec,
     get_output_act_qspec,
     get_weight_qspec,
 )
-from torch.ao.quantization.pt2e.quantizer.xnnpack_quantizer import ( # noqa: F401
+from torch.ao.quantization.quantizer.xnnpack_quantizer import ( # noqa: F401
     get_symmetric_quantization_config,
 )
 from torch.ao.quantization.backend_config import BackendConfig
@@ -1,4 +1,5 @@
-from .xnnpack_quantizer import XNNPACKQuantizer
+from .composable_quantizer import ComposableQuantizer
+from .embedding_quantizer import EmbeddingQuantizer
 from .quantizer import (
     DerivedQuantizationSpec,
     EdgeOrNode,
@@ -6,16 +7,14 @@ from .quantizer import (
     OperatorConfig,
     OperatorPatternType,
     QuantizationAnnotation,
+    QuantizationConfig,
     QuantizationSpec,
     QuantizationSpecBase,
     Quantizer,
     SharedQuantizationSpec,
-    QuantizationConfig,
 )
 from .x86_inductor_quantizer import X86InductorQuantizer
-
-from .composable_quantizer import ComposableQuantizer
-from .embedding_quantizer import EmbeddingQuantizer
+from .xnnpack_quantizer import XNNPACKQuantizer
 
 __all__ = [
     "ComposableQuantizer",
@@ -5,7 +5,8 @@ from typing import List, Set
 
 import torch
 import torch.nn.functional as F
-from torch.ao.quantization.pt2e.quantizer.quantizer import (
+from torch.ao.quantization.observer import PerChannelMinMaxObserver
+from torch.ao.quantization.quantizer.quantizer import (
     OperatorConfig,
     OperatorPatternType,
     QuantizationAnnotation,
@@ -13,13 +14,13 @@ from torch.ao.quantization.pt2e.quantizer.quantizer import (
     QuantizationSpec,
     Quantizer,
 )
-from torch.ao.quantization.observer import PerChannelMinMaxObserver
 
 __all__ = [
     "get_embedding_operators_config",
     "EmbeddingQuantizer",
 ]
 
+
 def get_embedding_operators_config() -> OperatorConfig:
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.uint8,
@@ -1,12 +1,12 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from torch.fx import Node
-from typing import Callable, List, NamedTuple, Optional, Dict, Union, Tuple
-from torch.ao.quantization import ObserverOrFakeQuantize
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
-from torch import Tensor
+from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
 
 import torch
+from torch import Tensor
+from torch.ao.quantization import ObserverOrFakeQuantize
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+from torch.fx import Node
 
 __all__ = [
     "Quantizer",
@@ -32,17 +32,21 @@ SUPPORTED_QSCHEMES = [
     torch.per_channel_affine_float_qparams,
 ]
 
+
 class QuantizationSpecBase(ABC):
     """Base class for different types of quantization specs that allows users to
    specify how to quantize a Tensor (input/output of a Node) in the model
     """
+
     pass
 
+
 @dataclass(eq=True, frozen=True)
 class QuantizationSpec(QuantizationSpecBase):
     """Quantization spec for common operators that allows user to specify how to
    quantize a Tensor, this includes dtype, quant_min, quant_max etc.
     """
+
     dtype: torch.dtype
     # observer or fake_quantize constructor such as
     # MinMaxObserver, PerChannelHistogramObserver etc.
@@ -79,6 +83,7 @@ class QuantizationSpec(QuantizationSpecBase):
         if self.ch_axis is not None and self.ch_axis < 0:
             raise ValueError("Ch_axis is < 0.")
 
+
 @dataclass(eq=True, frozen=True)
 class FixedQParamsQuantizationSpec(QuantizationSpecBase):
     dtype: torch.dtype
@@ -88,6 +93,7 @@ class FixedQParamsQuantizationSpec(QuantizationSpecBase):
     quant_max: Optional[int] = None
     qscheme: Optional[torch.qscheme] = None
 
+
 """
 The way we refer to other points of quantization in the graph will be either
 an input edge or an output value
@@ -95,19 +101,22 @@ input edge is the connection between input node and the node consuming the input
 output value is an fx Node
 """
 EdgeOrNode = Union[Tuple[Node, Node], Node]
-EdgeOrNode.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"
+EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer"
 
+
 @dataclass(eq=True, frozen=True)
 class SharedQuantizationSpec(QuantizationSpecBase):
     """
     Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
     """
+
     edge_or_node: EdgeOrNode
 
+
 @dataclass(eq=True, frozen=True)
 class DerivedQuantizationSpec(QuantizationSpecBase):
-    """ quantization spec for the Tensors whose quantization parameters are derived from other Tensors
-    """
+    """Quantization spec for the Tensors whose quantization parameters are derived from other Tensors"""
+
     derived_from: List[EdgeOrNode]
     derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]]
     dtype: torch.dtype
@@ -115,6 +124,7 @@ class DerivedQuantizationSpec(QuantizationSpecBase):
     quant_max: Optional[int] = None
     qscheme: Optional[torch.qscheme] = None
 
+
 # In the absence of better name, just winging it with QuantizationConfig
 @dataclass(eq=True, frozen=True)
 class QuantizationConfig:
@@ -125,8 +135,10 @@ class QuantizationConfig:
     # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
     is_qat: bool = False
 
+
 OperatorPatternType = List[Callable]
-OperatorPatternType.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"
+OperatorPatternType.__module__ = "torch.ao.quantization.quantizer.quantizer"
 
+
 class OperatorConfig(NamedTuple):
     # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
@@ -140,6 +152,7 @@ class OperatorConfig(NamedTuple):
     config: QuantizationConfig
     operators: List[OperatorPatternType]
 
+
 @dataclass
 class QuantizationAnnotation:
     """How are input arguemnt or output should be quantized,
@@ -157,8 +170,8 @@ class QuantizationAnnotation:
     # whether the node is annotated or not
     _annotated: bool = False
 
-class Quantizer(ABC):
 
+class Quantizer(ABC):
     # annotate nodes in the graph with observer or fake quant constructors
     # to convey the desired way of quantization
     @abstractmethod
@@ -1,7 +1,7 @@
 from typing import List
 
 import torch
-from torch.ao.quantization.pt2e.quantizer.quantizer import (
+from torch.ao.quantization.quantizer.quantizer import (
     QuantizationAnnotation,
     QuantizationConfig,
     QuantizationSpec,
@@ -15,6 +15,7 @@ __all__ = [
     "get_bias_qspec",
 ]
 
+
 def get_input_act_qspec(quantization_config: QuantizationConfig):
     if quantization_config is None:
         return None
@@ -91,11 +92,11 @@ def _annotate_output_qspec(node: Node, qspec):
 
 def _is_sym_size_node(node: Node):
     return (
-        node.op == "call_function" and
-        node.target == torch.ops.aten.sym_size.default or
-        node.target == torch.ops.aten.sym_numel.default or
-        node.target == torch.ops.aten.sym_numel or
-        node.target == torch.ops.aten.sym_size
+        node.op == "call_function"
+        and node.target == torch.ops.aten.sym_size.default
+        or node.target == torch.ops.aten.sym_numel.default
+        or node.target == torch.ops.aten.sym_numel
+        or node.target == torch.ops.aten.sym_size
     )
 
 
@@ -1,45 +1,45 @@
-import torch
-import torch.nn.functional as F
 import copy
 import functools
 import itertools
 import operator
-from .quantizer import (
-    OperatorConfig,
-    OperatorPatternType,
-    QuantizationConfig,
-    QuantizationSpec,
-    Quantizer,
-    QuantizationAnnotation,
+from typing import Any, Dict, List, Optional, Set
+
+import torch
+import torch.nn.functional as F
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    PerChannelMinMaxObserver,
+    PlaceholderObserver,
 )
 from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
-from torch.ao.quantization.pt2e.quantizer.utils import (
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+from torch.ao.quantization.quantizer.utils import (
+    get_bias_qspec,
     get_input_act_qspec,
     get_output_act_qspec,
     get_weight_qspec,
-    get_bias_qspec,
 )
-from .xnnpack_quantizer import (
-    _is_annotated,
-)
-from torch.ao.quantization.observer import (
-    HistogramObserver,
-    PlaceholderObserver,
-    PerChannelMinMaxObserver,
-)
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
-from typing import List, Dict, Optional, Set, Any
 from torch.fx import Node
 from torch.fx.passes.utils.source_matcher_utils import (
     get_source_partitions,
     SourcePartition,
 )
+from .quantizer import (
+    OperatorConfig,
+    OperatorPatternType,
+    QuantizationAnnotation,
+    QuantizationConfig,
+    QuantizationSpec,
+    Quantizer,
+)
+from .xnnpack_quantizer import _is_annotated
 
 __all__ = [
     "X86InductorQuantizer",
     "get_default_x86_inductor_quantization_config",
 ]
 
+
 def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     # TODO: Add more supported operators here.
     supported_operators: Dict[str, List[OperatorPatternType]] = {
@@ -71,7 +71,9 @@ def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
 
 def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
     supported_config_and_operators: List[OperatorConfig] = []
-    for quantization_config in [get_default_x86_inductor_quantization_config(), ]:
+    for quantization_config in [
+        get_default_x86_inductor_quantization_config(),
+    ]:
         ops = _supported_quantized_operators()
         for pattern_list in ops.values():
             supported_config_and_operators.append(
@@ -82,8 +84,9 @@ def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
 
 @functools.lru_cache
 def get_default_x86_inductor_quantization_config():
-    act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \
+    act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
         HistogramObserver
+    )
 
     # Copy from x86 default qconfig from torch/ao/quantization/qconfig.py
     act_quantization_spec = QuantizationSpec(
@@ -92,10 +95,14 @@ def get_default_x86_inductor_quantization_config():
         quant_max=255, # reduce_range=False
         qscheme=torch.per_tensor_affine,
         is_dynamic=False,
-        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
+        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+            eps=2**-12
+        ),
     )
 
-    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver
+    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        PerChannelMinMaxObserver
+    )
     extra_args: Dict[str, Any] = {"eps": 2**-12}
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
@@ -104,18 +111,21 @@ def get_default_x86_inductor_quantization_config():
         qscheme=torch.per_channel_symmetric,
         ch_axis=0, # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv
         is_dynamic=False,
-        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args),
+        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+            **extra_args
+        ),
     )
+    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+        PlaceholderObserver
+    )
-    bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver
     bias_quantization_spec = QuantizationSpec(
-        dtype=torch.float,
-        observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
+        dtype=torch.float, observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
     )
     quantization_config = QuantizationConfig(
         act_quantization_spec,
         act_quantization_spec,
         weight_quantization_spec,
-        bias_quantization_spec
+        bias_quantization_spec,
     )
     return quantization_config
 
@@ -170,8 +180,7 @@ class X86InductorQuantizer(Quantizer):
         annotate_output: bool,
         quantization_config: QuantizationConfig,
     ) -> None:
-        """ Helper function to annotate the conv node
-        """
+        """Helper function to annotate the conv node"""
         input_qspec_map = {}
         input_node = conv_node.args[0]
         assert isinstance(input_node, Node)
@@ -187,20 +196,18 @@ class X86InductorQuantizer(Quantizer):
                 input_qspec_map=input_qspec_map,
                 # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                 output_qspec=get_output_act_qspec(quantization_config),
-                _annotated=True
+                _annotated=True,
             )
         else:
             conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                input_qspec_map=input_qspec_map,
-                _annotated=True
+                input_qspec_map=input_qspec_map, _annotated=True
             )
 
     def _get_output_nodes_of_partitions(
         self,
         partition_list: List[SourcePartition],
     ) -> List[torch.fx.Node]:
-        """ Helper function to get the output node list from partition list
-        """
+        """Helper function to get the output node list from partition list"""
         output_node_list = []
         for partition in partition_list:
             if len(partition.output_nodes) > 1:
@@ -209,7 +216,9 @@ class X86InductorQuantizer(Quantizer):
             assert isinstance(output_node, Node)
             output_node_list.append(output_node)
         if len(output_node_list) != len(partition_list):
-            raise ValueError("length of output_node_list should equal to length of partition_list")
+            raise ValueError(
+                "length of output_node_list should equal to length of partition_list"
+            )
         return output_node_list
 
     def _get_input_idx_for_binary_node(
@@ -237,8 +246,7 @@ class X86InductorQuantizer(Quantizer):
         return conv_gemm_node_idx, extra_input_node_idx
 
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-        """ just handling global spec for now
-        """
+        """just handling global spec for now"""
         model = self._annotate_for_static_quantization_config(model)
         return model
 
@@ -268,28 +276,34 @@ class X86InductorQuantizer(Quantizer):
             conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions(
                 [conv_partition, binary_partition, unary_partition]
             )
-            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(conv_node, binary_node)
+            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
+                conv_node, binary_node
+            )
             if (conv_node_idx is None) or (extra_input_node_idx is None):
                 continue
             if conv_node != binary_node.args[conv_node_idx]:
                 raise ValueError(f"{conv_node} doesn't match input of binary node")
             extra_input_node = binary_node.args[extra_input_node_idx]
-            if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.convolution.default
+            ):
                 # No conv node found to be fused with add
                 continue
             if _is_annotated([unary_node, binary_node, conv_node]):
                 continue
             self._annotate_conv_node_helper(conv_node, False, quantization_config)
             binary_node_input_qspec_map = {}
-            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(quantization_config)
+            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
+                quantization_config
+            )
             binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                input_qspec_map=binary_node_input_qspec_map,
-                _annotated=True
+                input_qspec_map=binary_node_input_qspec_map, _annotated=True
             )
             unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
                 # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                 output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
-                _annotated=True
+                _annotated=True,
             )
 
     def _annotate_conv2d_binary(
@@ -304,26 +318,33 @@ class X86InductorQuantizer(Quantizer):
             conv_node, binary_node = self._get_output_nodes_of_partitions(
                 [conv_partition, binary_partition]
             )
-            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(conv_node, binary_node)
+            conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node(
+                conv_node, binary_node
+            )
             if (conv_node_idx is None) or (extra_input_node_idx is None):
                 continue
             if conv_node != binary_node.args[conv_node_idx]:
                 raise ValueError(f"{conv_node} doesn't match input of binary node")
             extra_input_node = binary_node.args[extra_input_node_idx]
             assert isinstance(conv_node, Node)
-            if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.convolution.default
+            ):
                 # No conv node found to be fused with add
                 continue
             if _is_annotated([binary_node, conv_node]):
                 continue
             self._annotate_conv_node_helper(conv_node, False, quantization_config)
             binary_node_input_qspec_map = {}
-            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(quantization_config)
+            binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec(
+                quantization_config
+            )
             binary_node.meta["quantization_annotation"] = QuantizationAnnotation(
                 input_qspec_map=binary_node_input_qspec_map,
                 # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                 output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
-                _annotated=True
+                _annotated=True,
             )
 
     def _annotate_conv2d_unary(
@@ -337,7 +358,10 @@ class X86InductorQuantizer(Quantizer):
             conv_node, unary_node = self._get_output_nodes_of_partitions(
                 [conv_partition, unary_partition]
             )
-            if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.convolution.default
+            ):
                 continue
             if _is_annotated([unary_node, conv_node]):
                 continue
@@ -345,7 +369,7 @@ class X86InductorQuantizer(Quantizer):
             unary_node.meta["quantization_annotation"] = QuantizationAnnotation(
                 # TODO<leslie> Remove the annotate of output when oneDNN qconv support fp32 out.
                 output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
-                _annotated=True
+                _annotated=True,
             )
 
     def _annotate_conv2d(
@@ -21,8 +21,9 @@ from torch.ao.quantization.observer import (
 )
 
 from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
+from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
 
-from torch.ao.quantization.pt2e.quantizer.utils import (
+from torch.ao.quantization.quantizer.utils import (
     _annotate_input_qspec_map,
     _annotate_output_qspec,
     _is_sym_size_node,
@@ -32,7 +33,6 @@ from torch.ao.quantization.pt2e.quantizer.utils import (
     get_output_act_qspec,
     get_weight_qspec,
 )
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
 
 from torch.fx import Node
 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
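The quantizer.py hunks above carry the docstrings for the spec classes that now live under torch.ao.quantization.quantizer. A minimal illustrative sketch (not part of the commit) of building those specs from the new location; the observer choice and numeric values are assumptions that mirror the configs shown in the diff:

    import torch
    from torch.ao.quantization.observer import HistogramObserver
    from torch.ao.quantization.quantizer import (
        QuantizationAnnotation,
        QuantizationSpec,
        SharedQuantizationSpec,
    )

    # Per-tensor affine uint8 activation spec, in the style of the
    # get_default_x86_inductor_quantization_config() hunk above.
    act_qspec = QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
    )

    # Inside a custom Quantizer.annotate(), a node is annotated by attaching a
    # QuantizationAnnotation to node.meta; for a hypothetical `add_node` whose
    # second input should share qparams with its first input:
    #
    #   add_node.meta["quantization_annotation"] = QuantizationAnnotation(
    #       input_qspec_map={
    #           add_node.args[0]: act_qspec,
    #           add_node.args[1]: SharedQuantizationSpec((add_node.args[0], add_node)),
    #       },
    #       output_qspec=act_qspec,
    #       _annotated=True,
    #   )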