[quant][pt2e] Rename _pt2e to pt2e (#104668)
Summary:
X-link: https://github.com/pytorch/executorch/pull/3

att

Test Plan: Imported from OSS

Differential Revision: D47202807

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104668
Approved by: https://github.com/andrewor14
parent a63f3f4335
commit 7b4d080496
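For callers, the rename is a pure path change: the private `_pt2e` package becomes `pt2e`, while the entry points in `torch.ao.quantization._quantize_pt2e` keep their names. A minimal sketch of the import update, using only module paths that appear in the hunks below:

# Before this commit (old private package path):
# from torch.ao.quantization._pt2e.quantizer import X86InductorQuantizer
# After this commit:
from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer
from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer  # unchanged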
@@ -120,6 +120,13 @@ This module contains a few CustomConfig classes that's used in both eager mode a
     ConvertCustomConfig
     StandaloneModuleConfigEntry

+torch.ao.quantization.pt2e (quantization in pytorch 2.0 export)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: torch.ao.quantization.pt2e
+.. automodule:: torch.ao.quantization.pt2e.quantizer
+.. automodule:: torch.ao.quantization.pt2e.representation
+
 torch (quantization related functions)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -12,13 +12,13 @@ import weakref
 import torch

 import torch._dynamo as torchdynamo
-import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
 from torch import nn
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx
 from torch._inductor.utils import override_lowering, run_and_get_code
-from torch.ao.quantization._pt2e.quantizer import X86InductorQuantizer
 from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer
+from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer
 from torch.testing import FileCheck
 from torch.testing._internal.common_quantization import (
     skipIfNoDynamoSupport,
@@ -5,7 +5,7 @@ import unittest
 import torch
 import torch._dynamo as torchdynamo

-from torch.ao.quantization._pt2e.graph_utils import (
+from torch.ao.quantization.pt2e.graph_utils import (
     find_sequential_partitions,
     get_equivalent_types,
     update_equivalent_types_dict,
@@ -15,7 +15,7 @@ from torch.ao.quantization import (
     ObserverOrFakeQuantize,
     QConfigMapping,
 )
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
     EmbeddingQuantizer,

@@ -27,10 +27,10 @@ from torch.ao.quantization._pt2e.quantizer import (
     Quantizer,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization._pt2e.quantizer.composable_quantizer import ( # noqa: F811
+from torch.ao.quantization.pt2e.quantizer.composable_quantizer import ( # noqa: F811
     ComposableQuantizer,
 )
-from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import (
+from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import (
     get_symmetric_quantization_config,
 )
 from torch.ao.quantization._quantize_pt2e import (

@@ -1774,7 +1774,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             def forward(self, x, y):
                 return x + y

-        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
+        import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq

         quantizer = QNNPackQuantizer()
         operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)

@@ -1799,7 +1799,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             def forward(self, x, y):
                 return x + y

-        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
+        import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq

         quantizer = QNNPackQuantizer()
         operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
@@ -3,7 +3,7 @@ import copy
 import torch
 import torch._dynamo as torchdynamo
 import torch.nn as nn
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     X86InductorQuantizer,
 )
 from torch.ao.quantization._quantize_pt2e import (

@@ -19,7 +19,7 @@ from torch.testing._internal.common_quantization import (
 from torch.testing._internal.common_quantized import override_quantized_engine
 from enum import Enum
 import itertools
-import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
 from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle
@@ -139,10 +139,10 @@ FILENAME_ALLOWLIST |= {torch.utils._foreach_utils.__file__}
 # TODO: find a better way to express this path without having to import
 # `torch.ao.quantization._pt2e`, which interferes with memory profiling
 FILENAME_ALLOWLIST |= {
-    _module_dir(torch) + "ao/quantization/_pt2e/qat_utils.py",
-    _module_dir(torch) + "ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py",
-    _module_dir(torch) + "ao/quantization/_pt2e/representation/rewrite.py",
-    _module_dir(torch) + "ao/quantization/_pt2e/utils.py",
+    _module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
+    _module_dir(torch) + "ao/quantization/pt2e/quantizer/qnnpack_quantizer.py",
+    _module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
+    _module_dir(torch) + "ao/quantization/pt2e/utils.py",
 }

 # TODO (zhxchen17) Make exportdb importable here.
@@ -13,7 +13,7 @@ from torch._functorch.compile_utils import fx_graph_cse
 from torch._inductor.compile_fx import fake_tensor_prop
 from torch._inductor.fx_passes.freezing_patterns import freezing_passes
 from torch._inductor.fx_passes.post_grad import view_to_reshape
-from torch.ao.quantization._pt2e.utils import _fuse_conv_bn_
+from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
 from torch.fx.experimental.proxy_tensor import make_fx
 from . import config
 from .decomposition import select_decomp_table
@@ -1,23 +1,23 @@
 from torch.fx import GraphModule

-from ._pt2e.prepare import prepare
-from ._pt2e._propagate_annotation import propagate_annotation
-from ._pt2e.qat_utils import (
+from .pt2e.prepare import prepare
+from .pt2e._propagate_annotation import propagate_annotation
+from .pt2e.qat_utils import (
     _fuse_conv_bn_qat,
     _fold_conv_bn_qat,
 )
-from ._pt2e.utils import (
+from .pt2e.utils import (
     _get_node_name_to_scope,
     _fuse_conv_bn_,
     _rearrange_weight_observer_for_decomposed_linear,
     _replace_dropout_for_eval,
 )
-from ._pt2e.representation import reference_representation_rewrite
+from .pt2e.representation import reference_representation_rewrite
 from .fx.prepare import prepare as fx_prepare
 from .quantize_fx import _convert_to_reference_decomposed_fx
 from torch.ao.quantization import QConfigMapping
 # TODO: move quantizer to torch.ao.quantization
-from torch.ao.quantization._pt2e.quantizer import ( # noqa: F401
+from torch.ao.quantization.pt2e.quantizer import ( # noqa: F401
     OperatorConfig,
     OperatorPatternType,
     QuantizationConfig,

@@ -32,13 +32,13 @@ from torch.ao.quantization._pt2e.quantizer import ( # noqa: F401
     EmbeddingQuantizer,
     ComposableQuantizer,
 )
-from torch.ao.quantization._pt2e.quantizer.utils import ( # noqa: F401
+from torch.ao.quantization.pt2e.quantizer.utils import ( # noqa: F401
     get_bias_qspec,
     get_input_act_qspec,
     get_output_act_qspec,
     get_weight_qspec,
 )
-from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import ( # noqa: F401
+from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import ( # noqa: F401
     get_symmetric_quantization_config,
 )
 from torch.ao.quantization.backend_config import BackendConfig
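Taken together with the test hunks above, the renamed package is used roughly as follows. This is a hypothetical sketch, not code from the commit: the toy model, example inputs, and the `set_global` call are assumptions, while `prepare_pt2e_quantizer`, `convert_pt2e`, `QNNPackQuantizer`, and `get_symmetric_quantization_config` are names that appear in the diffs.

import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer
from torch.ao.quantization.pt2e.quantizer import QNNPackQuantizer
from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import get_symmetric_quantization_config

# Toy model and inputs (assumptions, not from the commit).
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 16, 16),)

# Export to an ATen-level graph, then prepare/convert through the pt2e flow.
exported_model, _ = torchdynamo.export(model, *example_inputs, aten_graph=True)
quantizer = QNNPackQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))  # set_global is assumed
prepared_model = prepare_pt2e_quantizer(exported_model, quantizer)
prepared_model(*example_inputs)  # calibration
quantized_model = convert_pt2e(prepared_model)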
@@ -106,7 +106,7 @@ from .custom_config import (
     PrepareCustomConfig,
     StandaloneModuleConfigEntry,
 )
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     EdgeOrNode,
     QuantizationSpec,
     FixedQParamsQuantizationSpec,
@@ -1,7 +1,7 @@
 from typing import Callable

 import torch
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
 )
@@ -10,6 +10,12 @@ from torch.fx.passes.utils.source_matcher_utils import (
     SourcePartition,
 )

+__all__ = [
+    "find_sequential_partitions",
+    "get_equivalent_types",
+    "update_equivalent_types_dict",
+]
+
 _EQUIVALENT_TYPES: List[Set] = [
     {torch.nn.Conv2d, torch.nn.functional.conv2d},
     {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
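The new `__all__` pins down the public surface of `graph_utils`. A hypothetical usage sketch follows; the argument order `(graph_module, list_of_module_types)` and the tiny exported model are assumptions based on how the quantizers later in this diff call the helper.

import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

# Export a small model, then look for Conv2d -> ReLU chains in the ATen graph.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 16, 16),)
gm, _ = torchdynamo.export(model, *example_inputs, aten_graph=True)

# Each entry groups the SourcePartitions for one conv/relu pair, in order.
for conv_partition, relu_partition in find_sequential_partitions(gm, [torch.nn.Conv2d, torch.nn.ReLU]):
    print(conv_partition.output_nodes, relu_partition.output_nodes)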
@@ -19,7 +19,7 @@ from torch.ao.quantization import QConfigMapping
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
 from typing import Dict, Tuple, Union, Any
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     EdgeOrNode,
 )
@@ -15,8 +15,8 @@ from .quantizer import (
     QuantizationSpecBase,
 )
 from .utils import (
-    _fold_bn_weights_into_conv_node,
-    _get_aten_graph_module,
+    fold_bn_weights_into_conv_node,
+    get_aten_graph_module,
 )

 # Example inputs for `_conv2d_bn_pattern`, `_qat_conv2d_bn_pattern`, and `_qat_conv2d_bn_pattern_no_bias`

@@ -496,7 +496,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     m.graph.eliminate_dead_code()
     m.recompile()
     example_inputs = _conv2d_bn_pattern_example_inputs
-    match_pattern = _get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
+    match_pattern = get_aten_graph_module(_conv2d_bn_pattern, example_inputs)

     # Step (1): Replace patterns with conv bias
     #

@@ -504,7 +504,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     # the replacement patterns for these two cases are substantially different.
     # TODO: use the public replace_pattern API once it also returns replacement nodes

-    replacement_pattern_with_conv_bias = _get_aten_graph_module(
+    replacement_pattern_with_conv_bias = get_aten_graph_module(
         _qat_conv2d_bn_pattern,
         example_inputs,
     )

@@ -519,7 +519,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:

     # Step (2): Replace patterns without conv bias

-    replacement_pattern_no_conv_bias = _get_aten_graph_module(
+    replacement_pattern_no_conv_bias = get_aten_graph_module(
         _qat_conv2d_bn_pattern_no_conv_bias,
         example_inputs,
     )

@@ -652,11 +652,11 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
         match_pattern = _get_quantized_qat_conv2d_bn_pattern(
             is_per_channel, has_relu, has_bias, relu_is_inplace,
         )
-        match_pattern = _get_aten_graph_module(match_pattern, example_inputs, **kwargs)
+        match_pattern = get_aten_graph_module(match_pattern, example_inputs, **kwargs)
         replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(
             is_per_channel, has_relu, has_bias, relu_is_inplace,
         )
-        replacement_pattern = _get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
+        replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
         replacements.extend(
             replace_pattern_with_filters(
                 m,

@@ -720,7 +720,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
         )

         # fold bn weights into conv
-        _fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
+        fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)

         # Copy over literal args for conv
         for original_node in _filter_nodes_map(r.nodes_map).values():
@@ -5,7 +5,7 @@ from typing import List, Set

 import torch
 import torch.nn.functional as F
-from torch.ao.quantization._pt2e.quantizer.quantizer import (
+from torch.ao.quantization.pt2e.quantizer.quantizer import (
     OperatorConfig,
     OperatorPatternType,
     QuantizationAnnotation,

@@ -15,6 +15,10 @@ from torch.ao.quantization._pt2e.quantizer.quantizer import (
 )
 from torch.ao.quantization.observer import PerChannelMinMaxObserver

+__all__ = [
+    "get_embedding_operators_config",
+    "EmbeddingQuantizer",
+]

 def get_embedding_operators_config() -> OperatorConfig:
     weight_quantization_spec = QuantizationSpec(
@@ -11,9 +11,9 @@ import torch
 import torch._dynamo as torchdynamo
 import torch.nn.functional as F

-from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
+from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

-from torch.ao.quantization._pt2e.quantizer.utils import (
+from torch.ao.quantization.pt2e.quantizer.utils import (
     _annotate_input_qspec_map,
     _annotate_output_qspec,
     _is_sym_size_node,

@@ -84,7 +84,7 @@ def _get_linear_patterns(input_size: List[int]):
     return [pattern_w_bias, pattern_wo_bias]


-def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
+def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     supported_operators: Dict[str, List[OperatorPatternType]] = {
         # Both conv and linear should be able to handle relu + hardtanh fusion since
         # those are clamp ops

@@ -107,7 +107,7 @@ def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternT
     return copy.deepcopy(supported_operators)


-def get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
+def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
     supported_config_and_operators: List[OperatorConfig] = []
     for quantization_config in [
         get_symmetric_quantization_config(),

@@ -115,7 +115,7 @@ def get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
         get_symmetric_quantization_config(is_per_channel=True),
         get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
     ]:
-        ops = supported_symmetric_quantized_operators()
+        ops = _supported_symmetric_quantized_operators()
         for op_string, pattern_list in ops.items():
             supported_config_and_operators.append(
                 OperatorConfig(quantization_config, pattern_list)

@@ -205,8 +205,8 @@ def get_symmetric_quantization_config(
     return quantization_config


-def get_supported_config_and_operators() -> List[OperatorConfig]:
-    return get_supported_symmetric_config_and_operators()
+def _get_supported_config_and_operators() -> List[OperatorConfig]:
+    return _get_supported_symmetric_config_and_operators()


 def _is_annotated(nodes: List[Node]):

@@ -225,7 +225,7 @@ def _is_annotated(nodes: List[Node]):


 class QNNPackQuantizer(Quantizer):
-    supported_config_and_operators = get_supported_config_and_operators()
+    supported_config_and_operators = _get_supported_config_and_operators()

     def __init__(self):
         super().__init__()
@@ -13,11 +13,13 @@ __all__ = [
     "QuantizationSpecBase",
     "QuantizationSpec",
     "FixedQParamsQuantizationSpec",
+    "EdgeOrNode",
     "SharedQuantizationSpec",
     "DerivedQuantizationSpec",
     "QuantizationAnnotation",
     "QuantizationConfig",
+    "OperatorPatternType",
     "OperatorConfig",
 ]

 # TODO: maybe remove torch.float32
@@ -86,17 +88,19 @@ class FixedQParamsQuantizationSpec(QuantizationSpecBase):
     quant_max: Optional[int] = None
     qscheme: Optional[torch.qscheme] = None

+"""
+The way we refer to other points of quantization in the graph will be either
+an input edge or an output value
+input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
+output value is an fx Node
+"""
 EdgeOrNode = Union[Tuple[Node, Node], Node]
+EdgeOrNode.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"

 @dataclass(eq=True, frozen=True)
 class SharedQuantizationSpec(QuantizationSpecBase):
     """
     Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
-
-    The way we refer to other points of quantization in the graph will be either
-    an input edge or an output value
-    input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
-    output value is an fx Node
     """
     edge_or_node: EdgeOrNode
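To make the EdgeOrNode convention above concrete, here is a hypothetical illustration (not from the commit): a tiny fx graph is traced just to obtain Node objects, then a SharedQuantizationSpec is pointed at an input edge and at an output value.

import torch
import torch.fx as fx
from torch.ao.quantization.pt2e.quantizer.quantizer import SharedQuantizationSpec

# Build a tiny fx graph just to obtain Node objects for illustration.
def f(x, y):
    return torch.add(torch.relu(x), y)

gm = fx.symbolic_trace(f)
relu_node, add_node = [n for n in gm.graph.nodes if n.op == "call_function"]

# An "input edge" is a (producer, consumer) tuple of Nodes ...
spec_for_edge = SharedQuantizationSpec((relu_node, add_node))
# ... while an "output value" is a single Node.
spec_for_output = SharedQuantizationSpec(relu_node)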
@@ -122,6 +126,7 @@ class QuantizationConfig:
     is_qat: bool = False

 OperatorPatternType = List[Callable]
+OperatorPatternType.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"

 OperatorConfig = NamedTuple(
     "OperatorConfig",
@@ -1,13 +1,19 @@
 from typing import List

 import torch
-from torch.ao.quantization._pt2e.quantizer.quantizer import (
+from torch.ao.quantization.pt2e.quantizer.quantizer import (
     QuantizationAnnotation,
     QuantizationConfig,
     QuantizationSpec,
 )
 from torch.fx import Node

+__all__ = [
+    "get_input_act_qspec",
+    "get_output_act_qspec",
+    "get_weight_qspec",
+    "get_bias_qspec",
+]

 def get_input_act_qspec(quantization_config: QuantizationConfig):
     if quantization_config is None:
@@ -12,8 +12,8 @@ from .quantizer import (
     Quantizer,
     QuantizationAnnotation,
 )
-from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
-from torch.ao.quantization._pt2e.quantizer.utils import (
+from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
+from torch.ao.quantization.pt2e.quantizer.utils import (
     get_input_act_qspec,
     get_output_act_qspec,
     get_weight_qspec,

@@ -40,7 +40,7 @@ __all__ = [
     "get_default_x86_inductor_quantization_config",
 ]

-def supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
+def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     # TODO: Add more supported operators here.
     supported_operators: Dict[str, List[OperatorPatternType]] = {
         "conv2d": [

@@ -69,10 +69,10 @@ def supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     return copy.deepcopy(supported_operators)


-def get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
+def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
     supported_config_and_operators: List[OperatorConfig] = []
     for quantization_config in [get_default_x86_inductor_quantization_config(), ]:
-        ops = supported_quantized_operators()
+        ops = _supported_quantized_operators()
         for op_string, pattern_list in ops.items():
             supported_config_and_operators.append(
                 OperatorConfig(quantization_config, pattern_list)

@@ -120,12 +120,12 @@ def get_default_x86_inductor_quantization_config():
     return quantization_config


-def get_supported_config_and_operators() -> List[OperatorConfig]:
-    return get_supported_x86_inductor_config_and_operators()
+def _get_supported_config_and_operators() -> List[OperatorConfig]:
+    return _get_supported_x86_inductor_config_and_operators()


 class X86InductorQuantizer(Quantizer):
-    supported_config_and_operators = get_supported_config_and_operators()
+    supported_config_and_operators = _get_supported_config_and_operators()

     def __init__(self):
         super().__init__()
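For the Inductor backend, the renamed quantizer is configured the same way. A hypothetical sketch: the `xiq` alias, `X86InductorQuantizer`, and `get_default_x86_inductor_quantization_config` appear in the diffs, while the `set_global` call is assumed to mirror QNNPackQuantizer's API.

import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer

# set_global is an assumption; it installs the default x86 Inductor config globally.
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
# quantizer is then passed to prepare_pt2e_quantizer as in the flow sketched earlier.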
@@ -1,7 +1,7 @@
 import torch
 from torch.fx import GraphModule
-from ..utils import _get_aten_graph_module
-from ..utils import _remove_tensor_overload_for_qdq_ops
+from ..utils import get_aten_graph_module
+from ..utils import remove_tensor_overload_for_qdq_ops
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
 from torch.fx.subgraph_rewriter import replace_pattern
@@ -117,11 +117,11 @@ _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
 ]

 def reference_representation_rewrite(model: GraphModule) -> GraphModule:
-    _remove_tensor_overload_for_qdq_ops(model)
+    remove_tensor_overload_for_qdq_ops(model)
     for example_inputs, pattern, replacement in _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS:
-        pattern = _get_aten_graph_module(pattern, example_inputs)
-        _remove_tensor_overload_for_qdq_ops(pattern)
-        replacement = _get_aten_graph_module(replacement, example_inputs)
-        _remove_tensor_overload_for_qdq_ops(replacement)
+        pattern = get_aten_graph_module(pattern, example_inputs)
+        remove_tensor_overload_for_qdq_ops(pattern)
+        replacement = get_aten_graph_module(replacement, example_inputs)
+        remove_tensor_overload_for_qdq_ops(replacement)
         matches = replace_pattern(model, pattern, replacement)
     return model
@@ -15,6 +15,11 @@ import copy
 import operator
 from typing import Any, Callable, Dict, Optional, Tuple

+__all__ = [
+    "fold_bn_weights_into_conv_node",
+    "get_aten_graph_module",
+    "remove_tensor_overload_for_qdq_ops",
+]

 def _get_tensor_constant_from_node(node, m):
     if node is None:

@@ -33,7 +38,7 @@ def _get_all_arguments(orig_args, orig_kwargs, args_schema):
         all_args.append(schema.default_value)
     return all_args

-def _fold_bn_weights_into_conv_node(
+def fold_bn_weights_into_conv_node(
     conv_node: Node,
     conv_weight_node: Node,
     conv_bias_node: Optional[Node],

@@ -111,7 +116,7 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
         conv_node = n
         conv_weight_node = conv_node.args[1]
         conv_bias_node = conv_node.args[2]
-        _fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)
+        fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)

     m.graph.eliminate_dead_code()
     m.recompile()

@@ -177,7 +182,7 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
         node_name_to_scope[n.name] = current_scope
     return node_name_to_scope

-def _get_aten_graph_module(
+def get_aten_graph_module(
     pattern: Callable,
     example_inputs: Tuple[Any, ...],
     **kwargs,
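The now-public `get_aten_graph_module` helper traces a pattern callable into an ATen-level GraphModule that the subgraph rewriter can match against, as the qat_utils and rewrite.py hunks show. A hypothetical usage sketch: the conv+relu pattern and its example inputs are illustrative; only the `(pattern, example_inputs, **kwargs)` signature is taken from the hunk above.

import torch
import torch.nn.functional as F
from torch.ao.quantization.pt2e.utils import get_aten_graph_module

# Illustrative pattern to trace (not from the commit).
def conv_relu_pattern(x, weight, bias):
    return F.relu(F.conv2d(x, weight, bias))

example_inputs = (
    torch.randn(1, 3, 8, 8),   # x
    torch.randn(4, 3, 3, 3),   # weight
    torch.randn(4),            # bias
)
# Returns an ATen-level GraphModule suitable for subgraph_rewriter.replace_pattern.
pattern_gm = get_aten_graph_module(conv_relu_pattern, example_inputs)
print(pattern_gm.graph)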
@@ -198,7 +203,7 @@ def _get_aten_graph_module(
     aten_pattern.recompile()
     return aten_pattern

-def _remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
+def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
     """ Remove .tensor overload for quantize/dequantize ops so that we can
     use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
     """

@@ -258,8 +263,8 @@ def _replace_dropout_for_eval(m: GraphModule):
         return F.dropout(x, p=0.5, training=False)

     example_inputs = (torch.randn(1),)
-    match_pattern = _get_aten_graph_module(dropout_train, example_inputs)
-    replacement_pattern = _get_aten_graph_module(dropout_eval, example_inputs)
+    match_pattern = get_aten_graph_module(dropout_train, example_inputs)
+    replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)

     # Note: The match pattern looks like:
     #