diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst
index 0e99517f3ab..ba3eac56a6e 100644
--- a/docs/source/quantization-support.rst
+++ b/docs/source/quantization-support.rst
@@ -120,6 +120,13 @@ This module contains a few CustomConfig classes that's used in both eager mode a
     ConvertCustomConfig
     StandaloneModuleConfigEntry

+torch.ao.quantization.pt2e (quantization in pytorch 2.0 export)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: torch.ao.quantization.pt2e
+.. automodule:: torch.ao.quantization.pt2e.quantizer
+.. automodule:: torch.ao.quantization.pt2e.representation
+
 torch (quantization related functions)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py
index c2280bee831..1764883efda 100644
--- a/test/inductor/test_inductor_freezing.py
+++ b/test/inductor/test_inductor_freezing.py
@@ -12,13 +12,13 @@ import weakref

 import torch
 import torch._dynamo as torchdynamo
-import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
 from torch import nn
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx
 from torch._inductor.utils import override_lowering, run_and_get_code
-from torch.ao.quantization._pt2e.quantizer import X86InductorQuantizer
 from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer
+from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer
 from torch.testing import FileCheck
 from torch.testing._internal.common_quantization import (
     skipIfNoDynamoSupport,
diff --git a/test/quantization/pt2e/test_graph_utils.py b/test/quantization/pt2e/test_graph_utils.py
index 5a833bc98a6..a20338a97e8 100644
--- a/test/quantization/pt2e/test_graph_utils.py
+++ b/test/quantization/pt2e/test_graph_utils.py
@@ -5,7 +5,7 @@ import unittest

 import torch
 import torch._dynamo as torchdynamo
-from torch.ao.quantization._pt2e.graph_utils import (
+from torch.ao.quantization.pt2e.graph_utils import (
     find_sequential_partitions,
     get_equivalent_types,
     update_equivalent_types_dict,
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 64ac2b1891a..f3eaf2916ea 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -15,7 +15,7 @@ from torch.ao.quantization import (
     ObserverOrFakeQuantize,
     QConfigMapping,
 )
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
     EmbeddingQuantizer,
@@ -27,10 +27,10 @@ from torch.ao.quantization._pt2e.quantizer import (
     Quantizer,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization._pt2e.quantizer.composable_quantizer import (  # noqa: F811
+from torch.ao.quantization.pt2e.quantizer.composable_quantizer import (  # noqa: F811
     ComposableQuantizer,
 )
-from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import (
+from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import (
     get_symmetric_quantization_config,
 )
 from torch.ao.quantization._quantize_pt2e import (
@@ -1774,7 +1774,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             def forward(self, x, y):
                 return x + y

-        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
+        import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq

         quantizer = QNNPackQuantizer()
         operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
@@ -1799,7 +1799,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             def forward(self, x, y):
                 return x + y

-        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
+        import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq

         quantizer = QNNPackQuantizer()
         operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py
index de7a8d77277..5030559db0f 100644
--- a/test/quantization/pt2e/test_x86inductor_quantizer.py
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -3,7 +3,7 @@ import copy
 import torch
 import torch._dynamo as torchdynamo
 import torch.nn as nn
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     X86InductorQuantizer,
 )
 from torch.ao.quantization._quantize_pt2e import (
@@ -19,7 +19,7 @@ from torch.testing._internal.common_quantization import (
 from torch.testing._internal.common_quantized import override_quantized_engine
 from enum import Enum
 import itertools
-import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
 from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle
diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py
index d6bfdba0728..0287e6a86cc 100644
--- a/torch/_dynamo/skipfiles.py
+++ b/torch/_dynamo/skipfiles.py
@@ -139,10 +139,10 @@ FILENAME_ALLOWLIST |= {torch.utils._foreach_utils.__file__}

 # TODO: find a better way to express this path without having to import
 # `torch.ao.quantization._pt2e`, which interferes with memory profiling
 FILENAME_ALLOWLIST |= {
-    _module_dir(torch) + "ao/quantization/_pt2e/qat_utils.py",
-    _module_dir(torch) + "ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py",
-    _module_dir(torch) + "ao/quantization/_pt2e/representation/rewrite.py",
-    _module_dir(torch) + "ao/quantization/_pt2e/utils.py",
+    _module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
+    _module_dir(torch) + "ao/quantization/pt2e/quantizer/qnnpack_quantizer.py",
+    _module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
+    _module_dir(torch) + "ao/quantization/pt2e/utils.py",
 }

 # TODO (zhxchen17) Make exportdb importable here.
diff --git a/torch/_inductor/freezing.py b/torch/_inductor/freezing.py
index 06b8e867fcd..6ab0e8fdd52 100644
--- a/torch/_inductor/freezing.py
+++ b/torch/_inductor/freezing.py
@@ -13,7 +13,7 @@ from torch._functorch.compile_utils import fx_graph_cse
 from torch._inductor.compile_fx import fake_tensor_prop
 from torch._inductor.fx_passes.freezing_patterns import freezing_passes
 from torch._inductor.fx_passes.post_grad import view_to_reshape
-from torch.ao.quantization._pt2e.utils import _fuse_conv_bn_
+from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
 from torch.fx.experimental.proxy_tensor import make_fx

 from . import config
 from .decomposition import select_decomp_table
diff --git a/torch/ao/quantization/_quantize_pt2e.py b/torch/ao/quantization/_quantize_pt2e.py
index 47cea5baca5..5f6b7b11750 100644
--- a/torch/ao/quantization/_quantize_pt2e.py
+++ b/torch/ao/quantization/_quantize_pt2e.py
@@ -1,23 +1,23 @@
 from torch.fx import GraphModule

-from ._pt2e.prepare import prepare
-from ._pt2e._propagate_annotation import propagate_annotation
-from ._pt2e.qat_utils import (
+from .pt2e.prepare import prepare
+from .pt2e._propagate_annotation import propagate_annotation
+from .pt2e.qat_utils import (
     _fuse_conv_bn_qat,
     _fold_conv_bn_qat,
 )
-from ._pt2e.utils import (
+from .pt2e.utils import (
     _get_node_name_to_scope,
     _fuse_conv_bn_,
     _rearrange_weight_observer_for_decomposed_linear,
     _replace_dropout_for_eval,
 )
-from ._pt2e.representation import reference_representation_rewrite
+from .pt2e.representation import reference_representation_rewrite
 from .fx.prepare import prepare as fx_prepare
 from .quantize_fx import _convert_to_reference_decomposed_fx
 from torch.ao.quantization import QConfigMapping
 # TODO: move quantizer to torch.ao.quantization
-from torch.ao.quantization._pt2e.quantizer import (  # noqa: F401
+from torch.ao.quantization.pt2e.quantizer import (  # noqa: F401
     OperatorConfig,
     OperatorPatternType,
     QuantizationConfig,
@@ -32,13 +32,13 @@ from torch.ao.quantization._pt2e.quantizer import (  # noqa: F401
     EmbeddingQuantizer,
     ComposableQuantizer,
 )
-from torch.ao.quantization._pt2e.quantizer.utils import (  # noqa: F401
+from torch.ao.quantization.pt2e.quantizer.utils import (  # noqa: F401
     get_bias_qspec,
     get_input_act_qspec,
     get_output_act_qspec,
     get_weight_qspec,
 )
-from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import (  # noqa: F401
+from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import (  # noqa: F401
     get_symmetric_quantization_config,
 )
 from torch.ao.quantization.backend_config import BackendConfig
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 8f8fe9424f0..d14bf444ccd 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -106,7 +106,7 @@ from .custom_config import (
     PrepareCustomConfig,
     StandaloneModuleConfigEntry,
 )
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     EdgeOrNode,
     QuantizationSpec,
     FixedQParamsQuantizationSpec,
diff --git a/torch/ao/quantization/_pt2e/__init__.py b/torch/ao/quantization/pt2e/__init__.py
similarity index 100%
rename from torch/ao/quantization/_pt2e/__init__.py
rename to torch/ao/quantization/pt2e/__init__.py
diff --git a/torch/ao/quantization/_pt2e/_propagate_annotation.py b/torch/ao/quantization/pt2e/_propagate_annotation.py
similarity index 97%
rename from torch/ao/quantization/_pt2e/_propagate_annotation.py
rename to torch/ao/quantization/pt2e/_propagate_annotation.py
index 87b0725d02e..bdf2a3ca2fe 100644
--- a/torch/ao/quantization/_pt2e/_propagate_annotation.py
+++ b/torch/ao/quantization/pt2e/_propagate_annotation.py
@@ -1,7 +1,7 @@
 from typing import Callable

 import torch
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
 )
diff --git a/torch/ao/quantization/_pt2e/graph_utils.py b/torch/ao/quantization/pt2e/graph_utils.py
similarity index 96%
rename from torch/ao/quantization/_pt2e/graph_utils.py
rename to torch/ao/quantization/pt2e/graph_utils.py
index b58ca2f97cd..0776e6db56e 100644
--- a/torch/ao/quantization/_pt2e/graph_utils.py
+++ b/torch/ao/quantization/pt2e/graph_utils.py
@@ -10,6 +10,12 @@ from torch.fx.passes.utils.source_matcher_utils import (
     SourcePartition,
 )

+__all__ = [
+    "find_sequential_partitions",
+    "get_equivalent_types",
+    "update_equivalent_types_dict",
+]
+
 _EQUIVALENT_TYPES: List[Set] = [
     {torch.nn.Conv2d, torch.nn.functional.conv2d},
     {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
diff --git a/torch/ao/quantization/_pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py
similarity index 99%
rename from torch/ao/quantization/_pt2e/prepare.py
rename to torch/ao/quantization/pt2e/prepare.py
index be5467a0f35..2a29a92a986 100644
--- a/torch/ao/quantization/_pt2e/prepare.py
+++ b/torch/ao/quantization/pt2e/prepare.py
@@ -19,7 +19,7 @@ from torch.ao.quantization import QConfigMapping
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
 from typing import Dict, Tuple, Union, Any
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     EdgeOrNode,
 )
diff --git a/torch/ao/quantization/_pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py
similarity index 98%
rename from torch/ao/quantization/_pt2e/qat_utils.py
rename to torch/ao/quantization/pt2e/qat_utils.py
index 50b23c8e1ff..4693d279440 100644
--- a/torch/ao/quantization/_pt2e/qat_utils.py
+++ b/torch/ao/quantization/pt2e/qat_utils.py
@@ -15,8 +15,8 @@ from .quantizer import (
     QuantizationSpecBase,
 )
 from .utils import (
-    _fold_bn_weights_into_conv_node,
-    _get_aten_graph_module,
+    fold_bn_weights_into_conv_node,
+    get_aten_graph_module,
 )

 # Example inputs for `_conv2d_bn_pattern`, `_qat_conv2d_bn_pattern`, and `_qat_conv2d_bn_pattern_no_bias`
@@ -496,7 +496,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     m.graph.eliminate_dead_code()
     m.recompile()
     example_inputs = _conv2d_bn_pattern_example_inputs
-    match_pattern = _get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
+    match_pattern = get_aten_graph_module(_conv2d_bn_pattern, example_inputs)

     # Step (1): Replace patterns with conv bias
     #
@@ -504,7 +504,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     # the replacement patterns for these two cases are substantially different.
     # TODO: use the public replace_pattern API once it also returns replacement nodes
-    replacement_pattern_with_conv_bias = _get_aten_graph_module(
+    replacement_pattern_with_conv_bias = get_aten_graph_module(
         _qat_conv2d_bn_pattern,
         example_inputs,
     )
@@ -519,7 +519,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:

     # Step (2): Replace patterns without conv bias

-    replacement_pattern_no_conv_bias = _get_aten_graph_module(
+    replacement_pattern_no_conv_bias = get_aten_graph_module(
         _qat_conv2d_bn_pattern_no_conv_bias,
         example_inputs,
     )
@@ -652,11 +652,11 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
         match_pattern = _get_quantized_qat_conv2d_bn_pattern(
             is_per_channel, has_relu, has_bias, relu_is_inplace,
         )
-        match_pattern = _get_aten_graph_module(match_pattern, example_inputs, **kwargs)
+        match_pattern = get_aten_graph_module(match_pattern, example_inputs, **kwargs)
         replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(
             is_per_channel, has_relu, has_bias, relu_is_inplace,
         )
-        replacement_pattern = _get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
+        replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
         replacements.extend(
             replace_pattern_with_filters(
                 m,
@@ -720,7 +720,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
         )

         # fold bn weights into conv
-        _fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
+        fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)

         # Copy over literal args for conv
         for original_node in _filter_nodes_map(r.nodes_map).values():
diff --git a/torch/ao/quantization/_pt2e/quantizer/__init__.py b/torch/ao/quantization/pt2e/quantizer/__init__.py
similarity index 100%
rename from torch/ao/quantization/_pt2e/quantizer/__init__.py
rename to torch/ao/quantization/pt2e/quantizer/__init__.py
diff --git a/torch/ao/quantization/_pt2e/quantizer/composable_quantizer.py b/torch/ao/quantization/pt2e/quantizer/composable_quantizer.py
similarity index 100%
rename from torch/ao/quantization/_pt2e/quantizer/composable_quantizer.py
rename to torch/ao/quantization/pt2e/quantizer/composable_quantizer.py
diff --git a/torch/ao/quantization/_pt2e/quantizer/embedding_quantizer.py b/torch/ao/quantization/pt2e/quantizer/embedding_quantizer.py
similarity index 95%
rename from torch/ao/quantization/_pt2e/quantizer/embedding_quantizer.py
rename to torch/ao/quantization/pt2e/quantizer/embedding_quantizer.py
index 0546b2bdcc4..1aebcf2235c 100644
--- a/torch/ao/quantization/_pt2e/quantizer/embedding_quantizer.py
+++ b/torch/ao/quantization/pt2e/quantizer/embedding_quantizer.py
@@ -5,7 +5,7 @@ from typing import List, Set

 import torch
 import torch.nn.functional as F
-from torch.ao.quantization._pt2e.quantizer.quantizer import (
+from torch.ao.quantization.pt2e.quantizer.quantizer import (
     OperatorConfig,
     OperatorPatternType,
     QuantizationAnnotation,
@@ -15,6 +15,10 @@ from torch.ao.quantization._pt2e.quantizer.quantizer import (
 )
 from torch.ao.quantization.observer import PerChannelMinMaxObserver

+__all__ = [
+    "get_embedding_operators_config",
+    "EmbeddingQuantizer",
+]

 def get_embedding_operators_config() -> OperatorConfig:
     weight_quantization_spec = QuantizationSpec(
diff --git a/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py b/torch/ao/quantization/pt2e/quantizer/qnnpack_quantizer.py
similarity index 98%
rename from torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
rename to torch/ao/quantization/pt2e/quantizer/qnnpack_quantizer.py
index f241f4d1d4e..5f4ab1e7f46 100644
--- a/torch/ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py
+++ b/torch/ao/quantization/pt2e/quantizer/qnnpack_quantizer.py
@@ -11,9 +11,9 @@ import torch
 import torch._dynamo as torchdynamo
 import torch.nn.functional as F

-from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
+from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

-from torch.ao.quantization._pt2e.quantizer.utils import (
+from torch.ao.quantization.pt2e.quantizer.utils import (
     _annotate_input_qspec_map,
     _annotate_output_qspec,
     _is_sym_size_node,
@@ -84,7 +84,7 @@ def _get_linear_patterns(input_size: List[int]):
     return [pattern_w_bias, pattern_wo_bias]


-def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
+def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     supported_operators: Dict[str, List[OperatorPatternType]] = {
         # Both conv and linear should be able to handle relu + hardtanh fusion since
         # those are clamp ops
@@ -107,7 +107,7 @@ def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternT
     return copy.deepcopy(supported_operators)


-def get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
+def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
     supported_config_and_operators: List[OperatorConfig] = []
     for quantization_config in [
         get_symmetric_quantization_config(),
@@ -115,7 +115,7 @@ def get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
         get_symmetric_quantization_config(is_per_channel=True),
         get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
     ]:
-        ops = supported_symmetric_quantized_operators()
+        ops = _supported_symmetric_quantized_operators()
         for op_string, pattern_list in ops.items():
             supported_config_and_operators.append(
                 OperatorConfig(quantization_config, pattern_list)
@@ -205,8 +205,8 @@ def get_symmetric_quantization_config(
     return quantization_config


-def get_supported_config_and_operators() -> List[OperatorConfig]:
-    return get_supported_symmetric_config_and_operators()
+def _get_supported_config_and_operators() -> List[OperatorConfig]:
+    return _get_supported_symmetric_config_and_operators()


 def _is_annotated(nodes: List[Node]):
@@ -225,7 +225,7 @@ def _is_annotated(nodes: List[Node]):


 class QNNPackQuantizer(Quantizer):
-    supported_config_and_operators = get_supported_config_and_operators()
+    supported_config_and_operators = _get_supported_config_and_operators()

     def __init__(self):
         super().__init__()
diff --git a/torch/ao/quantization/_pt2e/quantizer/quantizer.py b/torch/ao/quantization/pt2e/quantizer/quantizer.py
similarity index 93%
rename from torch/ao/quantization/_pt2e/quantizer/quantizer.py
rename to torch/ao/quantization/pt2e/quantizer/quantizer.py
index 7b84cb8c444..3f7531c33a4 100644
--- a/torch/ao/quantization/_pt2e/quantizer/quantizer.py
+++ b/torch/ao/quantization/pt2e/quantizer/quantizer.py
@@ -13,11 +13,13 @@ __all__ = [
     "QuantizationSpecBase",
     "QuantizationSpec",
     "FixedQParamsQuantizationSpec",
+    "EdgeOrNode",
     "SharedQuantizationSpec",
     "DerivedQuantizationSpec",
     "QuantizationAnnotation",
     "QuantizationConfig",
     "OperatorPatternType",
+    "OperatorConfig",
 ]

 # TODO: maybe remove torch.float32
@@ -86,17 +88,19 @@ class FixedQParamsQuantizationSpec(QuantizationSpecBase):
     quant_max: Optional[int] = None
     qscheme: Optional[torch.qscheme] = None

+"""
+The way we refer to other points of quantization in the graph will be either
+an input edge or an output value
+input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
+output value is an fx Node
+"""
 EdgeOrNode = Union[Tuple[Node, Node], Node]
+EdgeOrNode.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"

 @dataclass(eq=True, frozen=True)
 class SharedQuantizationSpec(QuantizationSpecBase):
     """
     Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
-
-    The way we refer to other points of quantization in the graph will be either
-    an input edge or an output value
-    input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
-    output value is an fx Node
     """
     edge_or_node: EdgeOrNode
@@ -122,6 +126,7 @@ class QuantizationConfig:
     is_qat: bool = False

 OperatorPatternType = List[Callable]
+OperatorPatternType.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"

 OperatorConfig = NamedTuple(
     "OperatorConfig",
diff --git a/torch/ao/quantization/_pt2e/quantizer/utils.py b/torch/ao/quantization/pt2e/quantizer/utils.py
similarity index 95%
rename from torch/ao/quantization/_pt2e/quantizer/utils.py
rename to torch/ao/quantization/pt2e/quantizer/utils.py
index 35f72ee1fa7..4c5dd439003 100644
--- a/torch/ao/quantization/_pt2e/quantizer/utils.py
+++ b/torch/ao/quantization/pt2e/quantizer/utils.py
@@ -1,13 +1,19 @@
 from typing import List

 import torch
-from torch.ao.quantization._pt2e.quantizer.quantizer import (
+from torch.ao.quantization.pt2e.quantizer.quantizer import (
     QuantizationAnnotation,
     QuantizationConfig,
     QuantizationSpec,
 )
 from torch.fx import Node

+__all__ = [
+    "get_input_act_qspec",
+    "get_output_act_qspec",
+    "get_weight_qspec",
+    "get_bias_qspec",
+]

 def get_input_act_qspec(quantization_config: QuantizationConfig):
     if quantization_config is None:
diff --git a/torch/ao/quantization/_pt2e/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/pt2e/quantizer/x86_inductor_quantizer.py
similarity index 96%
rename from torch/ao/quantization/_pt2e/quantizer/x86_inductor_quantizer.py
rename to torch/ao/quantization/pt2e/quantizer/x86_inductor_quantizer.py
index fe345a0b3b6..147c8cc983a 100644
--- a/torch/ao/quantization/_pt2e/quantizer/x86_inductor_quantizer.py
+++ b/torch/ao/quantization/pt2e/quantizer/x86_inductor_quantizer.py
@@ -12,8 +12,8 @@ from .quantizer import (
     Quantizer,
     QuantizationAnnotation,
 )
-from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
-from torch.ao.quantization._pt2e.quantizer.utils import (
+from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
+from torch.ao.quantization.pt2e.quantizer.utils import (
     get_input_act_qspec,
     get_output_act_qspec,
     get_weight_qspec,
@@ -40,7 +40,7 @@ __all__ = [
     "get_default_x86_inductor_quantization_config",
 ]

-def supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
+def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     # TODO: Add more supported operators here.
     supported_operators: Dict[str, List[OperatorPatternType]] = {
         "conv2d": [
@@ -69,10 +69,10 @@ def supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     return copy.deepcopy(supported_operators)


-def get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
+def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
     supported_config_and_operators: List[OperatorConfig] = []
     for quantization_config in [get_default_x86_inductor_quantization_config(), ]:
-        ops = supported_quantized_operators()
+        ops = _supported_quantized_operators()
         for op_string, pattern_list in ops.items():
             supported_config_and_operators.append(
                 OperatorConfig(quantization_config, pattern_list)
@@ -120,12 +120,12 @@ def get_default_x86_inductor_quantization_config():
     return quantization_config


-def get_supported_config_and_operators() -> List[OperatorConfig]:
-    return get_supported_x86_inductor_config_and_operators()
+def _get_supported_config_and_operators() -> List[OperatorConfig]:
+    return _get_supported_x86_inductor_config_and_operators()


 class X86InductorQuantizer(Quantizer):
-    supported_config_and_operators = get_supported_config_and_operators()
+    supported_config_and_operators = _get_supported_config_and_operators()

     def __init__(self):
         super().__init__()
diff --git a/torch/ao/quantization/_pt2e/representation/__init__.py b/torch/ao/quantization/pt2e/representation/__init__.py
similarity index 100%
rename from torch/ao/quantization/_pt2e/representation/__init__.py
rename to torch/ao/quantization/pt2e/representation/__init__.py
diff --git a/torch/ao/quantization/_pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py
similarity index 93%
rename from torch/ao/quantization/_pt2e/representation/rewrite.py
rename to torch/ao/quantization/pt2e/representation/rewrite.py
index 8b12cb9fd92..7b3de95f5f6 100644
--- a/torch/ao/quantization/_pt2e/representation/rewrite.py
+++ b/torch/ao/quantization/pt2e/representation/rewrite.py
@@ -1,7 +1,7 @@
 import torch
 from torch.fx import GraphModule
-from ..utils import _get_aten_graph_module
-from ..utils import _remove_tensor_overload_for_qdq_ops
+from ..utils import get_aten_graph_module
+from ..utils import remove_tensor_overload_for_qdq_ops
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torch.fx.subgraph_rewriter import replace_pattern

@@ -117,11 +117,11 @@ _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
 ]

 def reference_representation_rewrite(model: GraphModule) -> GraphModule:
-    _remove_tensor_overload_for_qdq_ops(model)
+    remove_tensor_overload_for_qdq_ops(model)
     for example_inputs, pattern, replacement in _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS:
-        pattern = _get_aten_graph_module(pattern, example_inputs)
-        _remove_tensor_overload_for_qdq_ops(pattern)
-        replacement = _get_aten_graph_module(replacement, example_inputs)
-        _remove_tensor_overload_for_qdq_ops(replacement)
+        pattern = get_aten_graph_module(pattern, example_inputs)
+        remove_tensor_overload_for_qdq_ops(pattern)
+        replacement = get_aten_graph_module(replacement, example_inputs)
+        remove_tensor_overload_for_qdq_ops(replacement)
         matches = replace_pattern(model, pattern, replacement)
     return model
diff --git a/torch/ao/quantization/_pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py
similarity index 95%
rename from torch/ao/quantization/_pt2e/utils.py
rename to torch/ao/quantization/pt2e/utils.py
index 865434dd80b..ea991d40eea 100644
--- a/torch/ao/quantization/_pt2e/utils.py
+++ b/torch/ao/quantization/pt2e/utils.py
@@ -15,6 +15,11 @@ import copy
 import operator
 from typing import Any, Callable, Dict, Optional, Tuple

+__all__ = [
+    "fold_bn_weights_into_conv_node",
+    "get_aten_graph_module",
+    "remove_tensor_overload_for_qdq_ops",
+]

 def _get_tensor_constant_from_node(node, m):
     if node is None:
@@ -33,7 +38,7 @@ def _get_all_arguments(orig_args, orig_kwargs, args_schema):
         all_args.append(schema.default_value)
     return all_args

-def _fold_bn_weights_into_conv_node(
+def fold_bn_weights_into_conv_node(
     conv_node: Node,
     conv_weight_node: Node,
     conv_bias_node: Optional[Node],
@@ -111,7 +116,7 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
             conv_node = n
             conv_weight_node = conv_node.args[1]
             conv_bias_node = conv_node.args[2]
-            _fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)
+            fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)

     m.graph.eliminate_dead_code()
     m.recompile()
@@ -177,7 +182,7 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
             node_name_to_scope[n.name] = current_scope
     return node_name_to_scope

-def _get_aten_graph_module(
+def get_aten_graph_module(
     pattern: Callable,
     example_inputs: Tuple[Any, ...],
     **kwargs,
@@ -198,7 +203,7 @@ def _get_aten_graph_module(
     aten_pattern.recompile()
     return aten_pattern

-def _remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
+def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
     """ Remove .tensor overload for quantize/dequantize ops so that we can
     use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
     """
@@ -258,8 +263,8 @@ def _replace_dropout_for_eval(m: GraphModule):
         return F.dropout(x, p=0.5, training=False)

     example_inputs = (torch.randn(1),)
-    match_pattern = _get_aten_graph_module(dropout_train, example_inputs)
-    replacement_pattern = _get_aten_graph_module(dropout_eval, example_inputs)
+    match_pattern = get_aten_graph_module(dropout_train, example_inputs)
+    replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)

     # Note: The match pattern looks like:
     #
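
For reference, a minimal usage sketch of the renamed modules (not part of the patch). The import paths follow the updated tests above; the `torchdynamo.export(..., aten_graph=True)` call and the `set_global` configuration step are assumptions based on how `QNNPackQuantizer` is exercised on this branch and may differ on other branches.

# Minimal pt2e quantization sketch using the new (non-underscore) module paths.
import torch
import torch._dynamo as torchdynamo
import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq
from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer
from torch.ao.quantization.pt2e.quantizer import QNNPackQuantizer


class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y


example_inputs = (torch.randn(1, 3), torch.randn(1, 3))

# Export to an ATen-level graph (export signature as assumed for this branch).
m, _ = torchdynamo.export(M().eval(), *example_inputs, aten_graph=True)

# Configure the quantizer with a symmetric per-channel config
# (`set_global` is assumed from the QNNPackQuantizer API).
quantizer = QNNPackQuantizer()
operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)

# Insert observers, run calibration, then convert to the quantized representation.
m = prepare_pt2e_quantizer(m, quantizer)
m(*example_inputs)  # calibration pass
m = convert_pt2e(m)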