[quant][pt2e] Rename _pt2e to pt2e (#104668)

Summary:
X-link: https://github.com/pytorch/executorch/pull/3
att

Test Plan: Imported from OSS

Differential Revision: D47202807

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104668
Approved by: https://github.com/andrewor14

This commit is contained in:
parent a63f3f4335
commit 7b4d080496
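For downstream code, the rename is a pure import-path change: the module contents are untouched, only `_pt2e` becomes `pt2e`. A minimal before/after sketch using two imports that appear in the diff below:

# Before this commit:
#   from torch.ao.quantization._pt2e.quantizer import QNNPackQuantizer
#   from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
# After this commit:
from torch.ao.quantization.pt2e.quantizer import QNNPackQuantizer
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions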
@@ -120,6 +120,13 @@ This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization
     ConvertCustomConfig
     StandaloneModuleConfigEntry

+torch.ao.quantization.pt2e (quantization in pytorch 2.0 export)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: torch.ao.quantization.pt2e
+.. automodule:: torch.ao.quantization.pt2e.quantizer
+.. automodule:: torch.ao.quantization.pt2e.representation
+
 torch (quantization related functions)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -12,13 +12,13 @@ import weakref
 import torch

 import torch._dynamo as torchdynamo
-import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
 from torch import nn
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx
 from torch._inductor.utils import override_lowering, run_and_get_code
-from torch.ao.quantization._pt2e.quantizer import X86InductorQuantizer
 from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer
+from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer
 from torch.testing import FileCheck
 from torch.testing._internal.common_quantization import (
     skipIfNoDynamoSupport,
@@ -5,7 +5,7 @@ import unittest
 import torch
 import torch._dynamo as torchdynamo

-from torch.ao.quantization._pt2e.graph_utils import (
+from torch.ao.quantization.pt2e.graph_utils import (
     find_sequential_partitions,
     get_equivalent_types,
     update_equivalent_types_dict,
@@ -15,7 +15,7 @@ from torch.ao.quantization import (
     ObserverOrFakeQuantize,
     QConfigMapping,
 )
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
     DerivedQuantizationSpec,
     EmbeddingQuantizer,
@@ -27,10 +27,10 @@ from torch.ao.quantization._pt2e.quantizer import (
     Quantizer,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization._pt2e.quantizer.composable_quantizer import ( # noqa: F811
+from torch.ao.quantization.pt2e.quantizer.composable_quantizer import ( # noqa: F811
     ComposableQuantizer,
 )
-from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import (
+from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import (
     get_symmetric_quantization_config,
 )
 from torch.ao.quantization._quantize_pt2e import (
@@ -1774,7 +1774,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             def forward(self, x, y):
                 return x + y

-        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
+        import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq

         quantizer = QNNPackQuantizer()
         operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
@@ -1799,7 +1799,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             def forward(self, x, y):
                 return x + y

-        import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
+        import torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer as qq

         quantizer = QNNPackQuantizer()
         operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
@@ -3,7 +3,7 @@ import copy
 import torch
 import torch._dynamo as torchdynamo
 import torch.nn as nn
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     X86InductorQuantizer,
 )
 from torch.ao.quantization._quantize_pt2e import (
@@ -19,7 +19,7 @@ from torch.testing._internal.common_quantization import (
 from torch.testing._internal.common_quantized import override_quantized_engine
 from enum import Enum
 import itertools
-import torch.ao.quantization._pt2e.quantizer.x86_inductor_quantizer as xiq
+import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
 from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle
@@ -139,10 +139,10 @@ FILENAME_ALLOWLIST |= {torch.utils._foreach_utils.__file__}
 # TODO: find a better way to express this path without having to import
 # `torch.ao.quantization._pt2e`, which interferes with memory profiling
 FILENAME_ALLOWLIST |= {
-    _module_dir(torch) + "ao/quantization/_pt2e/qat_utils.py",
-    _module_dir(torch) + "ao/quantization/_pt2e/quantizer/qnnpack_quantizer.py",
-    _module_dir(torch) + "ao/quantization/_pt2e/representation/rewrite.py",
-    _module_dir(torch) + "ao/quantization/_pt2e/utils.py",
+    _module_dir(torch) + "ao/quantization/pt2e/qat_utils.py",
+    _module_dir(torch) + "ao/quantization/pt2e/quantizer/qnnpack_quantizer.py",
+    _module_dir(torch) + "ao/quantization/pt2e/representation/rewrite.py",
+    _module_dir(torch) + "ao/quantization/pt2e/utils.py",
 }

 # TODO (zhxchen17) Make exportdb importable here.
@@ -13,7 +13,7 @@ from torch._functorch.compile_utils import fx_graph_cse
 from torch._inductor.compile_fx import fake_tensor_prop
 from torch._inductor.fx_passes.freezing_patterns import freezing_passes
 from torch._inductor.fx_passes.post_grad import view_to_reshape
-from torch.ao.quantization._pt2e.utils import _fuse_conv_bn_
+from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
 from torch.fx.experimental.proxy_tensor import make_fx
 from . import config
 from .decomposition import select_decomp_table
@@ -1,23 +1,23 @@
 from torch.fx import GraphModule

-from ._pt2e.prepare import prepare
-from ._pt2e._propagate_annotation import propagate_annotation
-from ._pt2e.qat_utils import (
+from .pt2e.prepare import prepare
+from .pt2e._propagate_annotation import propagate_annotation
+from .pt2e.qat_utils import (
     _fuse_conv_bn_qat,
     _fold_conv_bn_qat,
 )
-from ._pt2e.utils import (
+from .pt2e.utils import (
     _get_node_name_to_scope,
     _fuse_conv_bn_,
     _rearrange_weight_observer_for_decomposed_linear,
     _replace_dropout_for_eval,
 )
-from ._pt2e.representation import reference_representation_rewrite
+from .pt2e.representation import reference_representation_rewrite
 from .fx.prepare import prepare as fx_prepare
 from .quantize_fx import _convert_to_reference_decomposed_fx
 from torch.ao.quantization import QConfigMapping
 # TODO: move quantizer to torch.ao.quantization
-from torch.ao.quantization._pt2e.quantizer import ( # noqa: F401
+from torch.ao.quantization.pt2e.quantizer import ( # noqa: F401
     OperatorConfig,
     OperatorPatternType,
     QuantizationConfig,
@@ -32,13 +32,13 @@ from torch.ao.quantization._pt2e.quantizer import ( # noqa: F401
     EmbeddingQuantizer,
     ComposableQuantizer,
 )
-from torch.ao.quantization._pt2e.quantizer.utils import ( # noqa: F401
+from torch.ao.quantization.pt2e.quantizer.utils import ( # noqa: F401
     get_bias_qspec,
     get_input_act_qspec,
     get_output_act_qspec,
     get_weight_qspec,
 )
-from torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer import ( # noqa: F401
+from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import ( # noqa: F401
     get_symmetric_quantization_config,
 )
 from torch.ao.quantization.backend_config import BackendConfig
@@ -106,7 +106,7 @@ from .custom_config import (
     PrepareCustomConfig,
     StandaloneModuleConfigEntry,
 )
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     EdgeOrNode,
     QuantizationSpec,
     FixedQParamsQuantizationSpec,
@@ -1,7 +1,7 @@
 from typing import Callable

 import torch
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
 )
@@ -10,6 +10,12 @@ from torch.fx.passes.utils.source_matcher_utils import (
     SourcePartition,
 )

+__all__ = [
+    "find_sequential_partitions",
+    "get_equivalent_types",
+    "update_equivalent_types_dict",
+]
+
 _EQUIVALENT_TYPES: List[Set] = [
     {torch.nn.Conv2d, torch.nn.functional.conv2d},
     {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
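For context, `find_sequential_partitions`, now part of the module's public surface via `__all__`, looks for chains of source-level ops in a captured graph. A rough usage sketch, assuming `gm` is a `torch.fx.GraphModule` exported from a model containing conv -> bn -> relu chains; the return structure (groups of `SourcePartition` objects, one per source op in the chain) follows the `source_matcher_utils` types imported above:

import torch
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

fused_partitions = find_sequential_partitions(
    gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU]
)
for conv_part, bn_part, relu_part in fused_partitions:
    # Each element is a SourcePartition covering the nodes that came from one source op.
    print(conv_part.output_nodes, bn_part.output_nodes, relu_part.output_nodes)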
@@ -19,7 +19,7 @@ from torch.ao.quantization import QConfigMapping
 from torch.ao.quantization.qconfig import QConfigAny
 from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
 from typing import Dict, Tuple, Union, Any
-from torch.ao.quantization._pt2e.quantizer import (
+from torch.ao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     EdgeOrNode,
 )
@@ -15,8 +15,8 @@ from .quantizer import (
     QuantizationSpecBase,
 )
 from .utils import (
-    _fold_bn_weights_into_conv_node,
-    _get_aten_graph_module,
+    fold_bn_weights_into_conv_node,
+    get_aten_graph_module,
 )

 # Example inputs for `_conv2d_bn_pattern`, `_qat_conv2d_bn_pattern`, and `_qat_conv2d_bn_pattern_no_bias`
@@ -496,7 +496,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     m.graph.eliminate_dead_code()
     m.recompile()
     example_inputs = _conv2d_bn_pattern_example_inputs
-    match_pattern = _get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
+    match_pattern = get_aten_graph_module(_conv2d_bn_pattern, example_inputs)

     # Step (1): Replace patterns with conv bias
     #
@@ -504,7 +504,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     # the replacement patterns for these two cases are substantially different.
     # TODO: use the public replace_pattern API once it also returns replacement nodes

-    replacement_pattern_with_conv_bias = _get_aten_graph_module(
+    replacement_pattern_with_conv_bias = get_aten_graph_module(
         _qat_conv2d_bn_pattern,
         example_inputs,
     )
@@ -519,7 +519,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:

     # Step (2): Replace patterns without conv bias

-    replacement_pattern_no_conv_bias = _get_aten_graph_module(
+    replacement_pattern_no_conv_bias = get_aten_graph_module(
         _qat_conv2d_bn_pattern_no_conv_bias,
         example_inputs,
     )
@@ -652,11 +652,11 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
         match_pattern = _get_quantized_qat_conv2d_bn_pattern(
             is_per_channel, has_relu, has_bias, relu_is_inplace,
         )
-        match_pattern = _get_aten_graph_module(match_pattern, example_inputs, **kwargs)
+        match_pattern = get_aten_graph_module(match_pattern, example_inputs, **kwargs)
         replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(
             is_per_channel, has_relu, has_bias, relu_is_inplace,
         )
-        replacement_pattern = _get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
+        replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
         replacements.extend(
             replace_pattern_with_filters(
                 m,
@@ -720,7 +720,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
         )

         # fold bn weights into conv
-        _fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
+        fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)

         # Copy over literal args for conv
         for original_node in _filter_nodes_map(r.nodes_map).values():
@@ -5,7 +5,7 @@ from typing import List, Set

 import torch
 import torch.nn.functional as F
-from torch.ao.quantization._pt2e.quantizer.quantizer import (
+from torch.ao.quantization.pt2e.quantizer.quantizer import (
     OperatorConfig,
     OperatorPatternType,
     QuantizationAnnotation,
@@ -15,6 +15,10 @@ from torch.ao.quantization._pt2e.quantizer.quantizer import (
 )
 from torch.ao.quantization.observer import PerChannelMinMaxObserver

+__all__ = [
+    "get_embedding_operators_config",
+    "EmbeddingQuantizer",
+]

 def get_embedding_operators_config() -> OperatorConfig:
     weight_quantization_spec = QuantizationSpec(
@@ -11,9 +11,9 @@ import torch
 import torch._dynamo as torchdynamo
 import torch.nn.functional as F

-from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
+from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

-from torch.ao.quantization._pt2e.quantizer.utils import (
+from torch.ao.quantization.pt2e.quantizer.utils import (
     _annotate_input_qspec_map,
     _annotate_output_qspec,
     _is_sym_size_node,
@@ -84,7 +84,7 @@ def _get_linear_patterns(input_size: List[int]):
     return [pattern_w_bias, pattern_wo_bias]


-def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
+def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     supported_operators: Dict[str, List[OperatorPatternType]] = {
         # Both conv and linear should be able to handle relu + hardtanh fusion since
         # those are clamp ops
@@ -107,7 +107,7 @@ def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     return copy.deepcopy(supported_operators)


-def get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
+def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
     supported_config_and_operators: List[OperatorConfig] = []
     for quantization_config in [
         get_symmetric_quantization_config(),
@@ -115,7 +115,7 @@ def get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
         get_symmetric_quantization_config(is_per_channel=True),
         get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
     ]:
-        ops = supported_symmetric_quantized_operators()
+        ops = _supported_symmetric_quantized_operators()
        for op_string, pattern_list in ops.items():
            supported_config_and_operators.append(
                OperatorConfig(quantization_config, pattern_list)
@@ -205,8 +205,8 @@ def get_symmetric_quantization_config(
     return quantization_config


-def get_supported_config_and_operators() -> List[OperatorConfig]:
-    return get_supported_symmetric_config_and_operators()
+def _get_supported_config_and_operators() -> List[OperatorConfig]:
+    return _get_supported_symmetric_config_and_operators()


 def _is_annotated(nodes: List[Node]):
@@ -225,7 +225,7 @@ def _is_annotated(nodes: List[Node]):


 class QNNPackQuantizer(Quantizer):
-    supported_config_and_operators = get_supported_config_and_operators()
+    supported_config_and_operators = _get_supported_config_and_operators()

     def __init__(self):
         super().__init__()
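To see the renamed modules working together, here is a rough end-to-end sketch of the prepare/calibrate/convert flow these files implement. It assumes the `set_global` configuration hook on `QNNPackQuantizer` and the torchdynamo export entry point used by the tests above; treat it as an illustration rather than the canonical recipe:

import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e_quantizer
from torch.ao.quantization.pt2e.quantizer import QNNPackQuantizer
from torch.ao.quantization.pt2e.quantizer.qnnpack_quantizer import (
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(1, 8),)
# Capture an ATen-level graph; the pt2e flow operates on exported graphs.
m, _ = torchdynamo.export(M().eval(), *example_inputs, aten_graph=True)

quantizer = QNNPackQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
m = prepare_pt2e_quantizer(m, quantizer)  # insert observers per quantizer annotations
m(*example_inputs)                        # calibrate
m = convert_pt2e(m)                       # lower to quantize/dequantize reference ops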
@@ -13,11 +13,13 @@ __all__ = [
     "QuantizationSpecBase",
     "QuantizationSpec",
     "FixedQParamsQuantizationSpec",
+    "EdgeOrNode",
     "SharedQuantizationSpec",
     "DerivedQuantizationSpec",
     "QuantizationAnnotation",
     "QuantizationConfig",
     "OperatorPatternType",
+    "OperatorConfig",
 ]

 # TODO: maybe remove torch.float32
@@ -86,17 +88,19 @@ class FixedQParamsQuantizationSpec(QuantizationSpecBase):
     quant_max: Optional[int] = None
     qscheme: Optional[torch.qscheme] = None

+"""
+The way we refer to other points of quantization in the graph will be either
+an input edge or an output value
+input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
+output value is an fx Node
+"""
 EdgeOrNode = Union[Tuple[Node, Node], Node]
+EdgeOrNode.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"

 @dataclass(eq=True, frozen=True)
 class SharedQuantizationSpec(QuantizationSpecBase):
     """
     Quantization spec for the Tensors whose quantization parameters are shared with other Tensors

-    The way we refer to other points of quantization in the graph will be either
-    an input edge or an output value
-    input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]
-    output value is an fx Node
     """
     edge_or_node: EdgeOrNode
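The comment moved above `EdgeOrNode` makes the type concrete: a quantization point is either an input edge `(producer_node, consumer_node)` or an output value (a single fx `Node`). A small hypothetical sketch, assuming `conv_node` and `add_node` are nodes of a captured graph:

from torch.ao.quantization.pt2e.quantizer import SharedQuantizationSpec

# Share quantization parameters with the input edge feeding add_node
# from conv_node (an input edge, i.e. a Tuple[Node, Node]):
spec_for_edge = SharedQuantizationSpec((conv_node, add_node))

# Share quantization parameters with the output of conv_node
# (an output value, i.e. a single fx Node):
spec_for_output = SharedQuantizationSpec(conv_node)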
@@ -122,6 +126,7 @@ class QuantizationConfig:
     is_qat: bool = False

 OperatorPatternType = List[Callable]
+OperatorPatternType.__module__ = "torch.ao.quantization.pt2e.quantizer.quantizer"

 OperatorConfig = NamedTuple(
     "OperatorConfig",
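Retagging `__module__` on these typing aliases is what lets documentation and introspection tools report the new public `pt2e` path instead of an internal one. Roughly, for a hypothetical alias of the same shape:

from typing import Callable, List

Alias = List[Callable]          # a typing generic alias; its default __module__ is "typing"
Alias.__module__ = "mypkg.api"  # retag, mirroring what the diff does for OperatorPatternType
# Tools that key off __module__ (e.g. Sphinx autodoc cross-references)
# now attribute Alias to mypkg.api.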
@@ -1,13 +1,19 @@
 from typing import List

 import torch
-from torch.ao.quantization._pt2e.quantizer.quantizer import (
+from torch.ao.quantization.pt2e.quantizer.quantizer import (
     QuantizationAnnotation,
     QuantizationConfig,
     QuantizationSpec,
 )
 from torch.fx import Node

+__all__ = [
+    "get_input_act_qspec",
+    "get_output_act_qspec",
+    "get_weight_qspec",
+    "get_bias_qspec",
+]

 def get_input_act_qspec(quantization_config: QuantizationConfig):
     if quantization_config is None:
@@ -12,8 +12,8 @@ from .quantizer import (
     Quantizer,
     QuantizationAnnotation,
 )
-from torch.ao.quantization._pt2e.graph_utils import find_sequential_partitions
-from torch.ao.quantization._pt2e.quantizer.utils import (
+from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
+from torch.ao.quantization.pt2e.quantizer.utils import (
     get_input_act_qspec,
     get_output_act_qspec,
     get_weight_qspec,
@@ -40,7 +40,7 @@ __all__ = [
     "get_default_x86_inductor_quantization_config",
 ]

-def supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
+def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     # TODO: Add more supported operators here.
     supported_operators: Dict[str, List[OperatorPatternType]] = {
         "conv2d": [
@@ -69,10 +69,10 @@ def supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
     return copy.deepcopy(supported_operators)


-def get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
+def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]:
     supported_config_and_operators: List[OperatorConfig] = []
     for quantization_config in [get_default_x86_inductor_quantization_config(), ]:
-        ops = supported_quantized_operators()
+        ops = _supported_quantized_operators()
         for op_string, pattern_list in ops.items():
             supported_config_and_operators.append(
                 OperatorConfig(quantization_config, pattern_list)
@@ -120,12 +120,12 @@ def get_default_x86_inductor_quantization_config():
     return quantization_config


-def get_supported_config_and_operators() -> List[OperatorConfig]:
-    return get_supported_x86_inductor_config_and_operators()
+def _get_supported_config_and_operators() -> List[OperatorConfig]:
+    return _get_supported_x86_inductor_config_and_operators()


 class X86InductorQuantizer(Quantizer):
-    supported_config_and_operators = get_supported_config_and_operators()
+    supported_config_and_operators = _get_supported_config_and_operators()

     def __init__(self):
         super().__init__()
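The x86 Inductor quantizer is driven the same way as the QNNPack one. A short sketch, assuming `X86InductorQuantizer` exposes the same `set_global` configuration hook exercised by the tests above:

from torch.ao.quantization.pt2e.quantizer import X86InductorQuantizer
import torch.ao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq

quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
# The quantizer can then be passed to prepare_pt2e_quantizer(),
# as in the QNNPack sketch earlier in this diff.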
@@ -1,7 +1,7 @@
 import torch
 from torch.fx import GraphModule
-from ..utils import _get_aten_graph_module
-from ..utils import _remove_tensor_overload_for_qdq_ops
+from ..utils import get_aten_graph_module
+from ..utils import remove_tensor_overload_for_qdq_ops
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
 from torch.fx.subgraph_rewriter import replace_pattern
@@ -117,11 +117,11 @@ _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS = [
 ]

 def reference_representation_rewrite(model: GraphModule) -> GraphModule:
-    _remove_tensor_overload_for_qdq_ops(model)
+    remove_tensor_overload_for_qdq_ops(model)
     for example_inputs, pattern, replacement in _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS:
-        pattern = _get_aten_graph_module(pattern, example_inputs)
-        _remove_tensor_overload_for_qdq_ops(pattern)
-        replacement = _get_aten_graph_module(replacement, example_inputs)
-        _remove_tensor_overload_for_qdq_ops(replacement)
+        pattern = get_aten_graph_module(pattern, example_inputs)
+        remove_tensor_overload_for_qdq_ops(pattern)
+        replacement = get_aten_graph_module(replacement, example_inputs)
+        remove_tensor_overload_for_qdq_ops(replacement)
         matches = replace_pattern(model, pattern, replacement)
     return model
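`reference_representation_rewrite` is built on `torch.fx.subgraph_rewriter.replace_pattern`, which swaps every occurrence of a pattern subgraph for a replacement subgraph. The mechanism in isolation, with a toy pattern/replacement pair:

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def forward(self, x):
        return torch.add(x, x)

def pattern(x):
    return torch.add(x, x)

def replacement(x):
    return torch.mul(x, 2)

gm = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
# gm now computes torch.mul(x, 2) wherever torch.add(x, x) appeared.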
@@ -15,6 +15,11 @@ import copy
 import operator
 from typing import Any, Callable, Dict, Optional, Tuple

+__all__ = [
+    "fold_bn_weights_into_conv_node",
+    "get_aten_graph_module",
+    "remove_tensor_overload_for_qdq_ops",
+]

 def _get_tensor_constant_from_node(node, m):
     if node is None:
@@ -33,7 +38,7 @@ def _get_all_arguments(orig_args, orig_kwargs, args_schema):
         all_args.append(schema.default_value)
     return all_args

-def _fold_bn_weights_into_conv_node(
+def fold_bn_weights_into_conv_node(
     conv_node: Node,
     conv_weight_node: Node,
     conv_bias_node: Optional[Node],
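For reference, the transformation this now-public function performs is the standard batch-norm fold: given BN running statistics `mean`, `var` and affine parameters `gamma`, `beta`, the conv weight and bias are rescaled per output channel so that the conv alone reproduces conv-plus-BN. A minimal numeric sketch of that arithmetic (the helper name here is illustrative, not from the diff):

import torch

def fold_bn_into_conv_params(w, b, gamma, beta, mean, var, eps=1e-5):
    # One scale factor per output channel.
    scale = gamma / torch.sqrt(var + eps)
    w_folded = w * scale.reshape(-1, 1, 1, 1)  # rescale each output-channel filter
    b_folded = (b - mean) * scale + beta       # shift and rescale the bias
    return w_folded, b_folded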
@@ -111,7 +116,7 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
         conv_node = n
         conv_weight_node = conv_node.args[1]
         conv_bias_node = conv_node.args[2]
-        _fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)
+        fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)

     m.graph.eliminate_dead_code()
     m.recompile()
@@ -177,7 +182,7 @@ def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]:
         node_name_to_scope[n.name] = current_scope
     return node_name_to_scope

-def _get_aten_graph_module(
+def get_aten_graph_module(
     pattern: Callable,
     example_inputs: Tuple[Any, ...],
     **kwargs,
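Now that it is public, `get_aten_graph_module` can be used directly to capture a pattern callable as an ATen-level `GraphModule` (it exports via torchdynamo under the hood, as the qat_utils call sites above do). A rough sketch:

import torch
from torch.ao.quantization.pt2e.utils import get_aten_graph_module

def pattern(x, weight):
    return torch.nn.functional.conv2d(x, weight)

example_inputs = (torch.randn(1, 3, 8, 8), torch.randn(4, 3, 3, 3))
aten_pattern = get_aten_graph_module(pattern, example_inputs)
aten_pattern.graph.print_tabular()  # nodes are now aten ops, e.g. aten.convolution.default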
@@ -198,7 +203,7 @@ def _get_aten_graph_module(
     aten_pattern.recompile()
     return aten_pattern

-def _remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
+def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
     """ Remove .tensor overload for quantize/dequantize ops so that we can
     use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
     """
@@ -258,8 +263,8 @@ def _replace_dropout_for_eval(m: GraphModule):
         return F.dropout(x, p=0.5, training=False)

     example_inputs = (torch.randn(1),)
-    match_pattern = _get_aten_graph_module(dropout_train, example_inputs)
-    replacement_pattern = _get_aten_graph_module(dropout_eval, example_inputs)
+    match_pattern = get_aten_graph_module(dropout_train, example_inputs)
+    replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)

     # Note: The match pattern looks like:
     #