[quant][pt2e] Rename _pt2e to pt2e (#104668)

Summary:
X-link: https://github.com/pytorch/executorch/pull/3

As titled: rename the private `torch.ao.quantization._pt2e` package to `torch.ao.quantization.pt2e`, update every import site, and drop the leading underscore from helpers the renamed package now exports in `__all__` (e.g. `get_aten_graph_module`, `fold_bn_weights_into_conv_node`, `remove_tensor_overload_for_qdq_ops`).

Test Plan: Imported from OSS

Differential Revision: D47202807

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104668
Approved by: https://github.com/andrewor14
Jerry Zhang 2023-07-15 06:34:17 +00:00 committed by PyTorch MergeBot
parent a63f3f4335
commit 7b4d080496
24 changed files with 103 additions and 70 deletions
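
The change is mechanical: the `_pt2e` package and a handful of helpers lose their leading underscore. A minimal sketch of the resulting import change, mirroring the test updates below (valid only on builds that include this commit):

```python
# Before this commit the package was private:
#   from torch.ao.quantization._pt2e.quantizer import QNNPackQuantizer
# After this commit the same objects import from the renamed path:
from torch.ao.quantization.pt2e.quantizer import QNNPackQuantizer
import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq

quantizer = QNNPackQuantizer()
operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
```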

@ -120,6 +120,13 @@ This module contains a few CustomConfig classes that's used in both eager mode a
ConvertCustomConfig
StandaloneModuleConfigEntry
torch.ao.quantization.pt2e (quantization in pytorch 2.0 export)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: torch.ao.quantization.pt2e
.. automodule:: torch.ao.quantization.pt2e.quantizer
.. automodule:: torch.ao.quantization.pt2e.representation
torch (quantization related functions)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
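
The new docs entry points at the PT2E quantization flow that now lives under the renamed package. Below is a rough end-to-end sketch assembled from API names appearing in this diff (`prepare_pt2e_quantizer`, `convert_pt2e`, `QNNPackQuantizer`, `get_symmetric_quantization_config`); the `set_global` call and the `torchdynamo.export` invocation are assumptions based on how the updated tests drive these APIs, and the entry points have shifted in later releases:

```python
import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer
from torch.ao.quantization.pt2e.quantizer import QNNPackQuantizer
from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import (
    get_symmetric_quantization_config,
)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 8),)
# Assumption: export to an ATen-level graph the way the tests of this era do.
m, _ = torchdynamo.export(M().eval(), *example_inputs, aten_graph=True)

quantizer = QNNPackQuantizer()
# Assumption: set_global is how the quantizer is configured in the test suite.
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

m = prepare_pt2e_quantizer(m, quantizer)  # insert observers/fake-quants
m(*example_inputs)                        # calibrate
m = convert_pt2e(m)                       # lower to the quantized representation
```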

@ -12,13 +12,13 @@ import weakref
import torch
import torch._dynamo as torchdynamo
import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torch import nn
from torch._inductor import config
from torch._inductor.compile_fx import compile_fx
from torch._inductor.utils import override_lowering, run_and_get_code
from torch.ao.quantization._pt2e.quantizer import X86InductorQuantizer
from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer
from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import (
skipIfNoDynamoSupport,

@ -5,7 +5,7 @@ import unittest
import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._pt2e.graph_utils import (
from torch.ao.quantization.pt2e.graph_utils import (
find_sequential_partitions,
get_equivalent_types,
update_equivalent_types_dict,

@ -15,7 +15,7 @@ from torch.ao.quantization import (
ObserverOrFakeQuantize,
QConfigMapping,
)
from torch.ao.quantization._pt2e.quantizer import (
from torch.ao.quantization.pt2e.quantizer import (
ComposableQuantizer,
DerivedQuantizationSpec,
EmbeddingQuantizer,
@ -27,10 +27,10 @@ from torch.ao.quantization._pt2e.quantizer import (
Quantizer,
SharedQuantizationSpec,
)
from torch.ao.quantization._pt2e.quantizer.composable_quantizer import ( # noqa: F811
from torch.ao.quantization.pt2e.quantizer.composable_quantizer import ( # noqa: F811
ComposableQuantizer,
)
from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import (
from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import (
get_symmetric_quantization_config,
)
from torch.ao.quantization._quantize_pt2e import (
@ -1774,7 +1774,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
def forward(self, x, y):
return x + y
import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq
quantizer = QNNPackQuantizer()
operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
@ -1799,7 +1799,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
def forward(self, x, y):
return x + y
import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq
quantizer = QNNPackQuantizer()
operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)

@ -3,7 +3,7 @@ import copy
import torch
import torch._dynamo as torchdynamo
import torch.nn as nn
from torch.ao.quantization._pt2e.quantizer import (
from torch.ao.quantization.pt2e.quantizer import (
X86InductorQuantizer,
)
from torch.ao.quantization._quantize_pt2e import (
@ -19,7 +19,7 @@ from torch.testing._internal.common_quantization import (
from torch.testing._internal.common_quantized import override_quantized_engine
from enum import Enum
import itertools
import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle

@ -139,10 +139,10 @@ FILENAME_ALLOWLIST |= {torch.utils._foreach_utils.__file__}
# TODO: find a better way to express this path without having to import
# `torch.ao.quantization._pt2e`, which interferes with memory profiling
FILENAME_ALLOWLIST |= {
_module_dir(torch) + "ao/quantization/_pt2e/qat_utils.py",
_module_dir(torch) + "ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py",
_module_dir(torch) + "ao/quantization/_pt2e/representation/rewrite.py",
_module_dir(torch) + "ao/quantization/_pt2e/utils.py",
_module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
_module_dir(torch) + "ao/quantization/pt2e/quantizer/qnnpack_quantizer.py",
_module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
_module_dir(torch) + "ao/quantization/pt2e/utils.py",
}
# TODO (zhxchen17) Make exportdb importable here.

@ -13,7 +13,7 @@ from torch._functorch.compile_utils import fx_graph_cse
from torch._inductor.compile_fx import fake_tensor_prop
from torch._inductor.fx_passes.freezing_patterns import freezing_passes
from torch._inductor.fx_passes.post_grad import view_to_reshape
from torch.ao.quantization._pt2e.utils import _fuse_conv_bn_
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
from torch.fx.experimental.proxy_tensor import make_fx
from . import config
from .decomposition import select_decomp_table

@ -1,23 +1,23 @@
from torch.fx import GraphModule
from ._pt2e.prepare import prepare
from ._pt2e._propagate_annotation import propagate_annotation
from ._pt2e.qat_utils import (
from .pt2e.prepare import prepare
from .pt2e._propagate_annotation import propagate_annotation
from .pt2e.qat_utils import (
_fuse_conv_bn_qat,
_fold_conv_bn_qat,
)
from ._pt2e.utils import (
from .pt2e.utils import (
_get_node_name_to_scope,
_fuse_conv_bn_,
_rearrange_weight_observer_for_decomposed_linear,
_replace_dropout_for_eval,
)
from ._pt2e.representation import reference_representation_rewrite
from .pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from torch.ao.quantization import QConfigMapping
# TODO: move quantizer to torch.ao.quantization
from torch.ao.quantization._pt2e.quantizer import ( # noqa: F401
from torch.ao.quantization.pt2e.quantizer import ( # noqa: F401
OperatorConfig,
OperatorPatternType,
QuantizationConfig,
@ -32,13 +32,13 @@ from torch.ao.quantization._pt2e.quantizer import ( # noqa: F401
EmbeddingQuantizer,
ComposableQuantizer,
)
from torch.ao.quantization._pt2e.quantizer.utils import ( # noqa: F401
from torch.ao.quantization.pt2e.quantizer.utils import ( # noqa: F401
get_bias_qspec,
get_input_act_qspec,
get_output_act_qspec,
get_weight_qspec,
)
from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import ( # noqa: F401
from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import ( # noqa: F401
get_symmetric_quantization_config,
)
from torch.ao.quantization.backend_config import BackendConfig

@ -106,7 +106,7 @@ from .custom_config import (
PrepareCustomConfig,
StandaloneModuleConfigEntry,
)
from torch.ao.quantization._pt2e.quantizer import (
from torch.ao.quantization.pt2e.quantizer import (
EdgeOrNode,
QuantizationSpec,
FixedQParamsQuantizationSpec,

@ -1,7 +1,7 @@
from typing import Callable
import torch
from torch.ao.quantization._pt2e.quantizer import (
from torch.ao.quantization.pt2e.quantizer import (
QuantizationAnnotation,
SharedQuantizationSpec,
)

@ -10,6 +10,12 @@ from torch.fx.passes.utils.source_matcher_utils import (
SourcePartition,
)
__all__ = [
"find_sequential_partitions",
"get_equivalent_types",
"update_equivalent_types_dict",
]
_EQUIVALENT_TYPES: List[Set] = [
{torch.nn.Conv2d, torch.nn.functional.conv2d},
{torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},

@ -19,7 +19,7 @@ from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from typing import Dict, Tuple, Union, Any
from torch.ao.quantization._pt2e.quantizer import (
from torch.ao.quantization.pt2e.quantizer import (
QuantizationAnnotation,
EdgeOrNode,
)

@ -15,8 +15,8 @@ from .quantizer import (
QuantizationSpecBase,
)
from .utils import (
_fold_bn_weights_into_conv_node,
_get_aten_graph_module,
fold_bn_weights_into_conv_node,
get_aten_graph_module,
)
# Example inputs for `_conv2d_bn_pattern`, `_qat_conv2d_bn_pattern`, and `_qat_conv2d_bn_pattern_no_bias`
@ -496,7 +496,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
m.graph.eliminate_dead_code()
m.recompile()
example_inputs = _conv2d_bn_pattern_example_inputs
match_pattern = _get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
match_pattern = get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
# Step (1): Replace patterns with conv bias
#
@ -504,7 +504,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
# the replacement patterns for these two cases are substantially different.
# TODO: use the public replace_pattern API once it also returns replacement nodes
replacement_pattern_with_conv_bias = _get_aten_graph_module(
replacement_pattern_with_conv_bias = get_aten_graph_module(
_qat_conv2d_bn_pattern,
example_inputs,
)
@ -519,7 +519,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
# Step (2): Replace patterns without conv bias
replacement_pattern_no_conv_bias = _get_aten_graph_module(
replacement_pattern_no_conv_bias = get_aten_graph_module(
_qat_conv2d_bn_pattern_no_conv_bias,
example_inputs,
)
@ -652,11 +652,11 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
match_pattern = _get_quantized_qat_conv2d_bn_pattern(
is_per_channel, has_relu, has_bias, relu_is_inplace,
)
match_pattern = _get_aten_graph_module(match_pattern, example_inputs, **kwargs)
match_pattern = get_aten_graph_module(match_pattern, example_inputs, **kwargs)
replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(
is_per_channel, has_relu, has_bias, relu_is_inplace,
)
replacement_pattern = _get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
replacements.extend(
replace_pattern_with_filters(
m,
@ -720,7 +720,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
)
# fold bn weights into conv
_fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
# Copy over literal args for conv
for original_node in _filter_nodes_map(r.nodes_map).values():
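
Besides the path rename, this hunk drops the underscore from `get_aten_graph_module` and `fold_bn_weights_into_conv_node`. The QAT conv-bn fusion itself works by subgraph pattern replacement; the sketch below is a standalone illustration of that mechanism using `torch.fx.subgraph_rewriter.replace_pattern`, not the PR's conv-bn patterns, and the toy pattern/replacement pair is made up for this example:

```python
import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace
from torch.fx.subgraph_rewriter import replace_pattern


# Toy pattern/replacement pair: rewrite relu(linear(x)) into the
# numerically identical clamp(linear(x), min=0).
def pattern(x, weight, bias):
    return F.relu(F.linear(x, weight, bias))


def replacement(x, weight, bias):
    return torch.clamp(F.linear(x, weight, bias), min=0.0)


class M(torch.nn.Module):
    def forward(self, x, weight, bias):
        return F.relu(F.linear(x, weight, bias))


gm = symbolic_trace(M())
matches = replace_pattern(gm, pattern, replacement)  # rewrites the graph in place
```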

@ -5,7 +5,7 @@ from typing import List, Set
import torch
import torch.nn.functional as F
from torch.ao.quantization._pt2e.quantizer.quantizer import (
from torch.ao.quantization.pt2e.quantizer.quantizer import (
OperatorConfig,
OperatorPatternType,
QuantizationAnnotation,
@ -15,6 +15,10 @@ from torch.ao.quantization._pt2e.quantizer.quantizer import (
)
from torch.ao.quantization.observer import PerChannelMinMaxObserver
__all__ = [
"get_embedding_operators_config",
"EmbeddingQuantizer",
]
def get_embedding_operators_config() -> OperatorConfig:
weight_quantization_spec = QuantizationSpec(

@ -11,9 +11,9 @@ import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization._pt2e.quantizer.utils import (
from torch.ao.quantization.pt2e.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
_is_sym_size_node,
@ -84,7 +84,7 @@ def _get_linear_patterns(input_size: List[int]):
return [pattern_w_bias, pattern_wo_bias]
def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
supported_operators: Dict[str, List[OperatorPatternType]] = {
# Both conv and linear should be able to handle relu + hardtanh fusion since
# those are clamp ops
@ -107,7 +107,7 @@ def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternT
return copy.deepcopy(supported_operators)
def get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
supported_config_and_operators: List[OperatorConfig] = []
for quantization_config in [
get_symmetric_quantization_config(),
@ -115,7 +115,7 @@ def get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
get_symmetric_quantization_config(is_per_channel=True),
get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
]:
ops = supported_symmetric_quantized_operators()
ops = _supported_symmetric_quantized_operators()
for op_string, pattern_list in ops.items():
supported_config_and_operators.append(
OperatorConfig(quantization_config, pattern_list)
@ -205,8 +205,8 @@ def get_symmetric_quantization_config(
return quantization_config
def get_supported_config_and_operators() -> List[OperatorConfig]:
return get_supported_symmetric_config_and_operators()
def _get_supported_config_and_operators() -> List[OperatorConfig]:
return _get_supported_symmetric_config_and_operators()
def _is_annotated(nodes: List[Node]):
@ -225,7 +225,7 @@ def _is_annotated(nodes: List[Node]):
class QNNPackQuantizer(Quantizer):
supported_config_and_operators = get_supported_config_and_operators()
supported_config_and_operators = _get_supported_config_and_operators()
def __init__(self):
super().__init__()

@ -13,11 +13,13 @@ __all__ = [
"QuantizationSpecBase",
"QuantizationSpec",
"FixedQParamsQuantizationSpec",
"EdgeOrNode",
"SharedQuantizationSpec",
"DerivedQuantizationSpec",
"QuantizationAnnotation",
"QuantizationConfig",
"OperatorPatternType",
"OperatorConfig",
]
# TODO: maybe remove torch.float32
@ -86,17 +88,19 @@ class FixedQParamsQuantizationSpec(QuantizationSpecBase):
quant_max: Optional[int] = None
qscheme: Optional[torch.qscheme] = None
"""
The way we refer to other points of quantization in the graph will be either
an input edge or an output value
input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
output value is an fx Node
"""
EdgeOrNode = Union[Tuple[Node, Node], Node]
EdgeOrNode.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"
@dataclass(eq=True, frozen=True)
class SharedQuantizationSpec(QuantizationSpecBase):
"""
Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
The way we refer to other points of quantization in the graph will be either
an input edge or an output value
input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
output value is an fx Node
"""
edge_or_node: EdgeOrNode
@ -122,6 +126,7 @@ class QuantizationConfig:
is_qat: bool = False
OperatorPatternType = List[Callable]
OperatorPatternType.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"
OperatorConfig = NamedTuple(
"OperatorConfig",

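The docstring relocated in this hunk is the key to reading `EdgeOrNode`: an input edge is the `(producer, consumer)` node pair feeding an op, while an output value is a single FX node. A small sketch of how `SharedQuantizationSpec` wraps either form; the toy `Add` module and the `symbolic_trace` call are stand-ins for the exported graph a real quantizer would annotate:

```python
import torch
from torch.fx import symbolic_trace
from torch.ao.quantization.pt2e.quantizer import SharedQuantizationSpec


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


# Stand-in graph; a real quantizer would see the exported model instead.
gm = symbolic_trace(Add())
x_node, y_node, add_node, _ = list(gm.graph.nodes)

# An input edge is the (producer, consumer) pair feeding an op ...
share_with_edge = SharedQuantizationSpec((x_node, add_node))
# ... while an output value is simply the node that produces the tensor.
share_with_output = SharedQuantizationSpec(add_node)
```
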
@ -1,13 +1,19 @@
from typing import List
import torch
from torch.ao.quantization._pt2e.quantizer.quantizer import (
from torch.ao.quantization.pt2e.quantizer.quantizer import (
QuantizationAnnotation,
QuantizationConfig,
QuantizationSpec,
)
from torch.fx import Node
__all__ = [
"get_input_act_qspec",
"get_output_act_qspec",
"get_weight_qspec",
"get_bias_qspec",
]
def get_input_act_qspec(quantization_config: QuantizationConfig):
if quantization_config is None:

@ -12,8 +12,8 @@ from .quantizer import (
Quantizer,
QuantizationAnnotation,
)
from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization._pt2e.quantizer.utils import (
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.pt2e.quantizer.utils import (
get_input_act_qspec,
get_output_act_qspec,
get_weight_qspec,
@ -40,7 +40,7 @@ __all__ = [
"get_default_x86_inductor_quantization_config",
]
def supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
# TODO: Add more supported operators here.
supported_operators: Dict[str, List[OperatorPatternType]] = {
"conv2d": [
@ -69,10 +69,10 @@ def supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
return copy.deepcopy(supported_operators)
def get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
supported_config_and_operators: List[OperatorConfig] = []
for quantization_config in [get_default_x86_inductor_quantization_config(), ]:
ops = supported_quantized_operators()
ops = _supported_quantized_operators()
for op_string, pattern_list in ops.items():
supported_config_and_operators.append(
OperatorConfig(quantization_config, pattern_list)
@ -120,12 +120,12 @@ def get_default_x86_inductor_quantization_config():
return quantization_config
def get_supported_config_and_operators() -> List[OperatorConfig]:
return get_supported_x86_inductor_config_and_operators()
def _get_supported_config_and_operators() -> List[OperatorConfig]:
return _get_supported_x86_inductor_config_and_operators()
class X86InductorQuantizer(Quantizer):
supported_config_and_operators = get_supported_config_and_operators()
supported_config_and_operators = _get_supported_config_and_operators()
def __init__(self):
super().__init__()

@ -1,7 +1,7 @@
import torch
from torch.fx import GraphModule
from ..utils import _get_aten_graph_module
from ..utils import _remove_tensor_overload_for_qdq_ops
from ..utils import get_aten_graph_module
from ..utils import remove_tensor_overload_for_qdq_ops
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torch.fx.subgraph_rewriter import replace_pattern
@ -117,11 +117,11 @@ _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
]
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
_remove_tensor_overload_for_qdq_ops(model)
remove_tensor_overload_for_qdq_ops(model)
for example_inputs, pattern, replacement in _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS:
pattern = _get_aten_graph_module(pattern, example_inputs)
_remove_tensor_overload_for_qdq_ops(pattern)
replacement = _get_aten_graph_module(replacement, example_inputs)
_remove_tensor_overload_for_qdq_ops(replacement)
pattern = get_aten_graph_module(pattern, example_inputs)
remove_tensor_overload_for_qdq_ops(pattern)
replacement = get_aten_graph_module(replacement, example_inputs)
remove_tensor_overload_for_qdq_ops(replacement)
matches = replace_pattern(model, pattern, replacement)
return model

@ -15,6 +15,11 @@ import copy
import operator
from typing import Any, Callable, Dict, Optional, Tuple
__all__ = [
"fold_bn_weights_into_conv_node",
"get_aten_graph_module",
"remove_tensor_overload_for_qdq_ops",
]
def _get_tensor_constant_from_node(node, m):
if node is None:
@ -33,7 +38,7 @@ def _get_all_arguments(orig_args, orig_kwargs, args_schema):
all_args.append(schema.default_value)
return all_args
def _fold_bn_weights_into_conv_node(
def fold_bn_weights_into_conv_node(
conv_node: Node,
conv_weight_node: Node,
conv_bias_node: Optional[Node],
@ -111,7 +116,7 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
conv_node = n
conv_weight_node = conv_node.args[1]
conv_bias_node = conv_node.args[2]
_fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)
fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)
m.graph.eliminate_dead_code()
m.recompile()
@ -177,7 +182,7 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
node_name_to_scope[n.name] = current_scope
return node_name_to_scope
def _get_aten_graph_module(
def get_aten_graph_module(
pattern: Callable,
example_inputs: Tuple[Any, ...],
**kwargs,
@ -198,7 +203,7 @@ def _get_aten_graph_module(
aten_pattern.recompile()
return aten_pattern
def _remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
""" Remove .tensor overload for quantize/dequantize ops so that we can
use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
"""
@ -258,8 +263,8 @@ def _replace_dropout_for_eval(m: GraphModule):
return F.dropout(x, p=0.5, training=False)
example_inputs = (torch.randn(1),)
match_pattern = _get_aten_graph_module(dropout_train, example_inputs)
replacement_pattern = _get_aten_graph_module(dropout_eval, example_inputs)
match_pattern = get_aten_graph_module(dropout_train, example_inputs)
replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)
# Note: The match pattern looks like:
#
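
The dropout rewrite above hinges on tracing both the train-mode and eval-mode callables down to ATen ops before matching. The body of `get_aten_graph_module` is not shown in this diff; the sketch below uses `make_fx` purely to illustrate that lowering step:

```python
import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx


def dropout_train(x):
    return F.dropout(x, p=0.5, training=True)


def dropout_eval(x):
    return F.dropout(x, p=0.5, training=False)


example_inputs = (torch.randn(1),)

# Lower both callables to ATen-level GraphModules so matching happens on
# aten ops rather than on the Python-level functional call.
match_pattern = make_fx(dropout_train)(*example_inputs)
replacement_pattern = make_fx(dropout_eval)(*example_inputs)
```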