Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[PT2E][Quantization] Refactor Quantizer and QNNPACKQuantizer (#99063)
This diff renames the quantization spec/config and operator config and moves these
data structures to the base quantizer.
The base quantizer API now has get_supported_operators, which returns the list of
patterns that a quantizer quantizes (a condensed sketch of this base API is shown below).
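For orientation, here is a minimal sketch of the refactored base quantizer API, condensed from the diff in this PR (method bodies are elided; the `OperatorConfig` type is defined further down in the diff):

```python
# Condensed sketch of the base Quantizer API after this refactor; not verbatim.
from abc import ABC, abstractmethod
from typing import List

import torch


class Quantizer(ABC):
    # annotate nodes in the graph with observer or fake quant constructors
    # to convey the desired way of quantization
    @abstractmethod
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        pass

    # verify that the annotated model is something this quantizer can handle
    @abstractmethod
    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    # list of (quantization config, operator patterns) pairs this quantizer supports
    @classmethod
    @abstractmethod
    def get_supported_operators(cls) -> List["OperatorConfig"]:
        pass
```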
There are two choices being debated for how to convey to the user what a particular
quantizer will quantize.
1. Modules. We just convey which nn.Modules will be quantized. Of course that
does not mean that equivalent functional variants won't be quantized; however,
for simplicity we just use nn.Module. If certain ops are quantized in a fused
manner, that is considered an internal detail. Pros and cons of this
approach:
Pros:
- Simple. Only nn.Modules are listed.
- The user does not have to see fusion patterns.
Cons:
- Perhaps confusing, because it is not clear whether supported = nn.Conv2d also
means that the quantizer supports functional.conv2d.
- Hiding fusion patterns means the user has no say in not fusing. If
conv2d + relu is fused and the user configures to quantize only conv, the quantizer
will also quantize the following relu as if conv2d + relu were fused.
2. Patterns. Be explicit about what is supported and enumerate all possible
combinations (a sketch of what such pattern-based configs could look like follows this list).
Pros:
- It is very clear what the quantizer will do; no surprises.
Cons:
- It is not simple to parse.
- It can be argued that fusion is an internal detail of the quantizer, so some
quantizer implementations may choose to expose fusion patterns, while others
may not and may not even provide any configurability.
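For illustration, a rough sketch of option 2, using the `OperatorConfig`, `QuantizationConfig`, and `OperatorPatternType` names introduced in the diff below; the concrete specs and the pattern list are abbreviated placeholders, not the full QNNPack set:

```python
# Sketch of pattern-based operator configs (option 2). The NamedTuple shapes
# follow the diff below; field values here are placeholders for brevity.
from typing import Callable, List, NamedTuple

import torch
import torch.nn.functional as F

# a pattern is a list of modules / functions, e.g. [nn.Conv2d, nn.ReLU]
OperatorPatternType = List[Callable]

# bundles activation/weight/bias quantization specs (left untyped in this sketch)
QuantizationConfig = NamedTuple(
    "QuantizationConfig",
    [("activation", object), ("weight", object), ("bias", object)],
)

# maps one quantization config to the operator patterns it applies to
OperatorConfig = NamedTuple(
    "OperatorConfig",
    [("config", QuantizationConfig), ("operators", List[OperatorPatternType])],
)

# a quantizer advertising that, under `config`, it handles standalone conv2d as
# well as the conv2d + relu fusion (module and functional variants)
config = QuantizationConfig(activation=None, weight=None, bias=None)
conv2d_patterns: List[OperatorPatternType] = [
    [torch.nn.Conv2d],
    [torch.nn.Conv2d, torch.nn.ReLU],
    [F.conv2d, F.relu],
]
supported = [OperatorConfig(config, conv2d_patterns)]
```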
One option is to move get_supported_operators/modules out of the base quantizer and
let each quantizer define its own way of communicating what is supported. The issue
with this is that when we want to "compose" multiple quantizers, there is no way
for the user to define the order of composition if the user does not know what a
quantizer supports. For example, quantizer A may quantize conv + relu while B quantizes
only conv, but B's implementation is fast. In that case you may compose (B, A)
such that B quantizes conv and A quantizes relu. Not knowing what A
and B support makes such composition harder (see the composition sketch below).
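A hypothetical sketch of how such composition could work (the `composed_annotate` helper is illustrative only; this PR does not add a composition API). It relies on quantizers marking the nodes they claim, as the QNNPack quantizer in this diff does via the `_annotated` flag:

```python
# Hypothetical composition sketch: quantizers are applied in order, and a later
# quantizer is expected to skip nodes an earlier one has already annotated.
import torch


def composed_annotate(model: torch.fx.GraphModule, quantizers) -> torch.fx.GraphModule:
    # Order matters: earlier quantizers win on overlapping patterns. Choosing the
    # order sensibly is only possible if get_supported_operators() tells the user
    # what each quantizer covers.
    for quantizer in quantizers:
        model = quantizer.annotate(model)
    return model


# usage sketch: prefer B's fast conv handling, let A pick up the remaining relu
# model = composed_annotate(model, [quantizer_b, quantizer_a])
```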
Differential Revision: [D44895547](https://our.internmc.facebook.com/intern/diff/D44895547/)
**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D44895547/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99063
Approved by: https://github.com/jerryzh168
This commit is contained in:
parent 888c65b6a4
commit 31f311a816
@ -1,39 +1,46 @@
# Owner(s): ["oncall: quantization"]
|
||||
import copy
|
||||
import itertools
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch._dynamo as torchdynamo
|
||||
import torch.nn as nn
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
from torch.ao.ns.fx.utils import compute_sqnr
|
||||
from torch.ao.quantization import get_default_qconfig, observer, QConfigMapping
|
||||
from torch.ao.quantization._pt2e.quantizer import (
|
||||
OperatorConfig,
|
||||
QNNPackQuantizer,
|
||||
Quantizer,
|
||||
)
|
||||
from torch.ao.quantization._quantize_pt2e import (
|
||||
convert_pt2e,
|
||||
prepare_pt2e,
|
||||
prepare_pt2e_quantizer,
|
||||
)
|
||||
from torch.ao.quantization.backend_config import get_qnnpack_backend_config
|
||||
from torch.ao.quantization.backend_config._qnnpack_pt2e import (
|
||||
get_qnnpack_pt2e_backend_config,
|
||||
)
|
||||
from torch.ao.quantization.backend_config._x86_inductor_pt2e import (
|
||||
get_x86_inductor_pt2e_backend_config,
|
||||
)
|
||||
from torch.ao.quantization.backend_config.x86 import get_x86_backend_config
|
||||
from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig
|
||||
from torch.ao.quantization.quantize_fx import (
|
||||
convert_fx,
|
||||
convert_to_reference_fx,
|
||||
prepare_fx,
|
||||
)
|
||||
from torch.testing._internal.common_quantization import (
|
||||
NodeSpec as ns,
|
||||
QuantizationTestCase,
|
||||
skip_if_no_torchvision,
|
||||
skipIfNoQNNPACK,
|
||||
skipIfNoX86,
|
||||
)
|
||||
from torch.testing._internal.common_quantization import NodeSpec as ns
|
||||
from torch.testing._internal.common_quantized import (
|
||||
override_quantized_engine,
|
||||
)
|
||||
from torch.ao.quantization import (
|
||||
get_default_qconfig,
|
||||
QConfigMapping,
|
||||
observer,
|
||||
)
|
||||
from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig
|
||||
from torch.ao.quantization.backend_config import (
|
||||
get_qnnpack_backend_config,
|
||||
)
|
||||
from torch.ao.quantization.backend_config._qnnpack_pt2e import get_qnnpack_pt2e_backend_config
|
||||
from torch.ao.quantization.backend_config._x86_inductor_pt2e import get_x86_inductor_pt2e_backend_config
|
||||
from torch.ao.quantization.backend_config.x86 import get_x86_backend_config
|
||||
from torch.ao.quantization.quantize_fx import prepare_fx, convert_to_reference_fx, convert_fx
|
||||
from torch.ao.quantization._pt2e.quantizer import Quantizer
|
||||
from torch.ao.quantization._pt2e.quantizer import QNNPackQuantizer
|
||||
from torch.ao.quantization._quantize_pt2e import prepare_pt2e, convert_pt2e, prepare_pt2e_quantizer
|
||||
from torch.ao.ns.fx.utils import (
|
||||
compute_sqnr,
|
||||
)
|
||||
import copy
|
||||
import itertools
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
from torch.testing._internal.common_quantized import override_quantized_engine
|
||||
|
||||
|
||||
@skipIfNoQNNPACK
|
||||
|
|
@ -62,8 +69,9 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
)
|
||||
|
||||
qconfig = get_default_qconfig("qnnpack")
|
||||
qconfig_mapping = QConfigMapping().set_global(qconfig) \
|
||||
.set_module_name("conv2", None)
|
||||
qconfig_mapping = (
|
||||
QConfigMapping().set_global(qconfig).set_module_name("conv2", None)
|
||||
)
|
||||
backend_config = get_qnnpack_pt2e_backend_config()
|
||||
m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
|
||||
m(*example_inputs)
|
||||
|
|
@ -74,7 +82,9 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
node_occurrence = {
|
||||
# two for input of the first conv, one for output for the first conv
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3,
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor
|
||||
): 3,
|
||||
}
|
||||
node_list = [
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
|
||||
|
|
@ -84,7 +94,10 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
ns.call_function(torch.ops.aten.convolution.default),
|
||||
]
|
||||
self.checkGraphModuleNodes(
|
||||
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
|
||||
m,
|
||||
expected_node_list=node_list,
|
||||
expected_node_occurrence=node_occurrence,
|
||||
)
|
||||
|
||||
def test_qconfig_module_type(self):
|
||||
class M(torch.nn.Module):
|
||||
|
|
@ -122,7 +135,9 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
node_occurrence = {
|
||||
# two for input and weight of the conv, one for output for the conv
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3,
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor
|
||||
): 3,
|
||||
}
|
||||
node_list = [
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
|
||||
|
|
@ -145,17 +160,28 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
class BackendAQuantizer(Quantizer):
|
||||
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
_DEFAULT_TARGET_DTYPE_INFO = {
|
||||
"input_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(dtype=torch.float),
|
||||
"output_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(dtype=torch.float),
|
||||
"input_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(
|
||||
dtype=torch.float
|
||||
),
|
||||
"output_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(
|
||||
dtype=torch.float
|
||||
),
|
||||
}
|
||||
for node in model.graph.nodes:
|
||||
node.meta["target_dtype_info"] = copy.deepcopy(_DEFAULT_TARGET_DTYPE_INFO)
|
||||
node.meta["target_dtype_info"] = copy.deepcopy(
|
||||
_DEFAULT_TARGET_DTYPE_INFO
|
||||
)
|
||||
for node in model.graph.nodes:
|
||||
if node.op == "call_function" and node.target == torch.ops.aten.convolution.default:
|
||||
if (
|
||||
node.op == "call_function"
|
||||
and node.target == torch.ops.aten.convolution.default
|
||||
):
|
||||
node.meta["target_dtype_info"] = {
|
||||
"input_act_obs_or_fq_ctr": observer.default_observer,
|
||||
"weight_obs_or_fq_ctr": observer.default_weight_observer,
|
||||
"bias_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(dtype=torch.float),
|
||||
"bias_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(
|
||||
dtype=torch.float
|
||||
),
|
||||
"output_act_obs_or_fq_ctr": observer.default_observer,
|
||||
"weight_index": 1,
|
||||
"bias_index": 2,
|
||||
|
|
@ -164,6 +190,10 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_supported_operators(cls) -> List[OperatorConfig]:
|
||||
pass
|
||||
|
||||
m = M().eval()
|
||||
example_inputs = (torch.randn(1, 3, 5, 5),)
|
||||
|
||||
|
|
@ -189,7 +219,8 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
|
||||
]
|
||||
self.checkGraphModuleNodes(
|
||||
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
|
||||
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
|
||||
)
|
||||
|
||||
def test_qnnpack_quantizer_conv(self):
|
||||
class M(torch.nn.Module):
|
||||
|
|
@ -201,9 +232,12 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
return self.conv(x)
|
||||
|
||||
import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
|
||||
|
||||
quantizer = QNNPackQuantizer()
|
||||
operator_spec = qq.get_default_per_channel_symmetric_qnnpack_operator_spec()
|
||||
quantizer.set_global(operator_spec)
|
||||
operator_config = (
|
||||
qq.get_default_per_channel_symmetric_qnnpack_quantization_config()
|
||||
)
|
||||
quantizer.set_global(operator_config)
|
||||
m = M().eval()
|
||||
example_inputs = (torch.randn(1, 3, 5, 5),)
|
||||
|
||||
|
|
@ -232,7 +266,8 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
|
||||
]
|
||||
self.checkGraphModuleNodes(
|
||||
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
|
||||
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
|
||||
)
|
||||
|
||||
def test_qnnpack_quantizer_obs_sharing_ops(self):
|
||||
class M(torch.nn.Module):
|
||||
|
|
@ -250,9 +285,12 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
return x
|
||||
|
||||
import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
|
||||
|
||||
quantizer = QNNPackQuantizer()
|
||||
operator_spec = qq.get_default_per_channel_symmetric_qnnpack_operator_spec()
|
||||
quantizer.set_global(operator_spec)
|
||||
operator_config = (
|
||||
qq.get_default_per_channel_symmetric_qnnpack_quantization_config()
|
||||
)
|
||||
quantizer.set_global(operator_config)
|
||||
m = M().eval()
|
||||
example_inputs = (torch.randn(1, 3, 5, 5),)
|
||||
|
||||
|
|
@ -279,11 +317,9 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel),
|
||||
ns.call_function(torch.ops.aten.convolution.default),
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
|
||||
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
|
||||
ns.call_function(torch.ops.aten.mean.dim),
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
|
||||
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
|
||||
ns.call_function(torch.ops.aten.hardtanh.default),
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
|
||||
|
|
@ -293,7 +329,8 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
|
||||
]
|
||||
self.checkGraphModuleNodes(
|
||||
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
|
||||
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
|
||||
)
|
||||
|
||||
def test_rearrange_weight_observer_for_decomposed_linear(self):
|
||||
"""
|
||||
|
|
@ -305,6 +342,7 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
weight - observer - t \
|
||||
input - observer - addmm/mm
|
||||
"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, with_bias, use_relu):
|
||||
super().__init__()
|
||||
|
|
@ -331,7 +369,7 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
tracing_mode="real",
|
||||
)
|
||||
|
||||
qconfig = get_default_qconfig('qnnpack')
|
||||
qconfig = get_default_qconfig("qnnpack")
|
||||
qconfig_mapping = QConfigMapping().set_global(qconfig)
|
||||
backend_config = get_qnnpack_pt2e_backend_config()
|
||||
m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
|
||||
|
|
@ -339,12 +377,19 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
# 1. Check graph nodes:
|
||||
# - args[0] of t should be the weight observer
|
||||
# - args[-1] of addmm/mm should be t
|
||||
error_msg = 'Weight observer is not correctly rearranged for decomposed linear'
|
||||
error_msg = (
|
||||
"Weight observer is not correctly rearranged for decomposed linear"
|
||||
)
|
||||
for node in m.graph.nodes:
|
||||
if node.target == torch.ops.aten.t.default:
|
||||
target = node.args[0].target
|
||||
self.assertTrue(isinstance(getattr(m, target), observer.ObserverBase), error_msg)
|
||||
elif node.target in (torch.ops.aten.addmm.default, torch.ops.aten.mm.default):
|
||||
self.assertTrue(
|
||||
isinstance(getattr(m, target), observer.ObserverBase), error_msg
|
||||
)
|
||||
elif node.target in (
|
||||
torch.ops.aten.addmm.default,
|
||||
torch.ops.aten.mm.default,
|
||||
):
|
||||
target = node.args[-1].target
|
||||
self.assertTrue(target == torch.ops.aten.t.default, error_msg)
|
||||
|
||||
|
|
@ -380,7 +425,9 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
|
||||
node_occurrence = {
|
||||
ns.call_function(torch.ops.aten.convolution.default): 1,
|
||||
ns.call_function(torch.ops.aten._native_batch_norm_legit_no_training.default): 1,
|
||||
ns.call_function(
|
||||
torch.ops.aten._native_batch_norm_legit_no_training.default
|
||||
): 1,
|
||||
}
|
||||
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
||||
|
||||
|
|
@ -394,10 +441,13 @@ class TestQuantizePT2E(QuantizationTestCase):
|
|||
# make sure bn is fused into conv
|
||||
node_occurrence = {
|
||||
ns.call_function(torch.ops.aten.convolution.default): 1,
|
||||
ns.call_function(torch.ops.aten._native_batch_norm_legit_no_training.default): 0,
|
||||
ns.call_function(
|
||||
torch.ops.aten._native_batch_norm_legit_no_training.default
|
||||
): 0,
|
||||
}
|
||||
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
||||
|
||||
|
||||
@skipIfNoQNNPACK
|
||||
class TestQuantizePT2EX86Inductor(QuantizationTestCase):
|
||||
@skipIfNoX86
|
||||
|
|
@ -417,7 +467,9 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase):
|
|||
inplace_relu_list = [True, False]
|
||||
with override_quantized_engine("x86"):
|
||||
with torch.no_grad():
|
||||
for use_relu, inplace_relu in itertools.product(use_relu_list, inplace_relu_list):
|
||||
for use_relu, inplace_relu in itertools.product(
|
||||
use_relu_list, inplace_relu_list
|
||||
):
|
||||
m = M(use_relu=use_relu, inplace_relu=inplace_relu).eval()
|
||||
example_inputs = (torch.randn(2, 3, 4, 4),)
|
||||
# program capture
|
||||
|
|
@ -433,7 +485,9 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase):
|
|||
qconfig = get_default_qconfig("x86")
|
||||
qconfig_mapping = QConfigMapping().set_global(qconfig)
|
||||
backend_config = get_x86_inductor_pt2e_backend_config()
|
||||
prepare_module = prepare_pt2e(export_module, qconfig_mapping, example_inputs, backend_config)
|
||||
prepare_module = prepare_pt2e(
|
||||
export_module, qconfig_mapping, example_inputs, backend_config
|
||||
)
|
||||
prepare_module(*example_inputs)
|
||||
convert_module = convert_pt2e(prepare_module)
|
||||
convert_module(*example_inputs)
|
||||
|
|
@ -441,38 +495,75 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase):
|
|||
# Fake quant should only be inserted at start and end
|
||||
node_occurrence = {
|
||||
# one for input and weight of the conv, one for output for the conv
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2,
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 1,
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 1,
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2,
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor
|
||||
): 2,
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.quantize_per_channel
|
||||
): 1,
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.dequantize_per_channel
|
||||
): 1,
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor
|
||||
): 2,
|
||||
}
|
||||
if use_relu:
|
||||
node_list = [
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor
|
||||
),
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor
|
||||
),
|
||||
ns.call_function(torch.ops.aten.convolution.default),
|
||||
ns.call_function(torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default),
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
|
||||
ns.call_function(
|
||||
torch.ops.aten.relu_.default
|
||||
if inplace_relu
|
||||
else torch.ops.aten.relu.default
|
||||
),
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor
|
||||
),
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor
|
||||
),
|
||||
]
|
||||
else:
|
||||
node_list = [
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor
|
||||
),
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor
|
||||
),
|
||||
ns.call_function(torch.ops.aten.convolution.default),
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor
|
||||
),
|
||||
ns.call_function(
|
||||
torch.ops.quantized_decomposed.dequantize_per_tensor
|
||||
),
|
||||
]
|
||||
self.checkGraphModuleNodes(convert_module,
|
||||
expected_node_occurrence=node_occurrence,
|
||||
expected_node_list=node_list)
|
||||
self.checkGraphModuleNodes(
|
||||
convert_module,
|
||||
expected_node_occurrence=node_occurrence,
|
||||
expected_node_list=node_list,
|
||||
)
|
||||
|
||||
# Step1: Ref result in 1.X fx path
|
||||
backend_config_1_x = get_x86_backend_config()
|
||||
m_copy = copy.deepcopy(m)
|
||||
m_prepare_fx = prepare_fx(m_copy, qconfig_mapping, example_inputs, backend_config=backend_config_1_x)
|
||||
m_prepare_fx = prepare_fx(
|
||||
m_copy,
|
||||
qconfig_mapping,
|
||||
example_inputs,
|
||||
backend_config=backend_config_1_x,
|
||||
)
|
||||
after_prepare_result_fx = m_prepare_fx(*example_inputs)
|
||||
m_convert_fx = convert_fx(m_prepare_fx, backend_config=backend_config_1_x)
|
||||
m_convert_fx = convert_fx(
|
||||
m_prepare_fx, backend_config=backend_config_1_x
|
||||
)
|
||||
ref_result = m_convert_fx(*example_inputs)
|
||||
|
||||
# Step2: Start to lowering into Inductor
|
||||
|
|
@ -483,11 +574,13 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase):
|
|||
inductor_res = run(*example_inputs)
|
||||
self.assertEqual(ref_result, inductor_res, atol=5e-2, rtol=5e-2)
|
||||
|
||||
|
||||
class TestQuantizePT2EModels(QuantizationTestCase):
|
||||
@skip_if_no_torchvision
|
||||
@skipIfNoQNNPACK
|
||||
def test_resnet18(self):
|
||||
import torchvision
|
||||
|
||||
with override_quantized_engine("qnnpack"):
|
||||
example_inputs = (torch.randn(1, 3, 224, 224),)
|
||||
m = torchvision.models.resnet18().eval()
|
||||
|
|
@ -510,7 +603,9 @@ class TestQuantizePT2EModels(QuantizationTestCase):
|
|||
|
||||
# checking that we inserted observers correctly for maxpool operator (input and
|
||||
# output share observer instance)
|
||||
self.assertEqual(id(m.activation_post_process_3), id(m.activation_post_process_2))
|
||||
self.assertEqual(
|
||||
id(m.activation_post_process_3), id(m.activation_post_process_2)
|
||||
)
|
||||
after_prepare_result = m(*example_inputs)
|
||||
m = convert_pt2e(m)
|
||||
|
||||
|
|
@ -518,7 +613,9 @@ class TestQuantizePT2EModels(QuantizationTestCase):
|
|||
|
||||
# comparing with existing fx graph mode quantization reference flow
|
||||
backend_config = get_qnnpack_backend_config()
|
||||
m_fx = prepare_fx(m_copy, qconfig_mapping, example_inputs, backend_config=backend_config)
|
||||
m_fx = prepare_fx(
|
||||
m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
|
||||
)
|
||||
after_prepare_result_fx = m_fx(*example_inputs)
|
||||
m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
|
||||
|
||||
|
|
@ -526,17 +623,24 @@ class TestQuantizePT2EModels(QuantizationTestCase):
|
|||
|
||||
# the result matches exactly after prepare
|
||||
self.assertEqual(after_prepare_result, after_prepare_result_fx)
|
||||
self.assertEqual(compute_sqnr(after_prepare_result, after_prepare_result_fx), torch.tensor(float("inf")))
|
||||
self.assertEqual(
|
||||
compute_sqnr(after_prepare_result, after_prepare_result_fx),
|
||||
torch.tensor(float("inf")),
|
||||
)
|
||||
# there are slight differences after convert due to different implementations
|
||||
# of quant/dequant
|
||||
self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1)
|
||||
self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35)
|
||||
|
||||
self.assertTrue(
|
||||
torch.max(after_quant_result - after_quant_result_fx) < 1e-1
|
||||
)
|
||||
self.assertTrue(
|
||||
compute_sqnr(after_quant_result, after_quant_result_fx) > 35
|
||||
)
|
||||
|
||||
@skip_if_no_torchvision
|
||||
@skipIfNoQNNPACK
|
||||
def test_resnet18_with_quantizer_api(self):
|
||||
import torchvision
|
||||
|
||||
with override_quantized_engine("qnnpack"):
|
||||
example_inputs = (torch.randn(1, 3, 224, 224),)
|
||||
m = torchvision.models.resnet18().eval()
|
||||
|
|
@ -551,13 +655,18 @@ class TestQuantizePT2EModels(QuantizationTestCase):
|
|||
|
||||
before_fusion_result = m(*example_inputs)
|
||||
import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
|
||||
|
||||
quantizer = QNNPackQuantizer()
|
||||
operator_spec = qq.get_default_per_channel_symmetric_qnnpack_operator_spec()
|
||||
quantizer.set_global(operator_spec)
|
||||
operator_config = (
|
||||
qq.get_default_per_channel_symmetric_qnnpack_quantization_config()
|
||||
)
|
||||
quantizer.set_global(operator_config)
|
||||
m = prepare_pt2e_quantizer(m, quantizer)
|
||||
# checking that we inserted observers correctly for maxpool operator (input and
|
||||
# output share observer instance)
|
||||
self.assertEqual(id(m.activation_post_process_3), id(m.activation_post_process_2))
|
||||
self.assertEqual(
|
||||
id(m.activation_post_process_3), id(m.activation_post_process_2)
|
||||
)
|
||||
after_prepare_result = m(*example_inputs)
|
||||
m = convert_pt2e(m)
|
||||
|
||||
|
|
@ -567,7 +676,9 @@ class TestQuantizePT2EModels(QuantizationTestCase):
|
|||
qconfig = default_per_channel_symmetric_qnnpack_qconfig
|
||||
qconfig_mapping = QConfigMapping().set_global(qconfig)
|
||||
backend_config = get_qnnpack_backend_config()
|
||||
m_fx = prepare_fx(m_copy, qconfig_mapping, example_inputs, backend_config=backend_config)
|
||||
m_fx = prepare_fx(
|
||||
m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
|
||||
)
|
||||
after_prepare_result_fx = m_fx(*example_inputs)
|
||||
m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
|
||||
|
||||
|
|
@ -579,8 +690,15 @@ class TestQuantizePT2EModels(QuantizationTestCase):
|
|||
# but we can still manually inspect the printed observers to make sure
|
||||
# it matches
|
||||
self.assertEqual(after_prepare_result, after_prepare_result_fx)
|
||||
self.assertEqual(compute_sqnr(after_prepare_result, after_prepare_result_fx), torch.tensor(float("inf")))
|
||||
self.assertEqual(
|
||||
compute_sqnr(after_prepare_result, after_prepare_result_fx),
|
||||
torch.tensor(float("inf")),
|
||||
)
|
||||
# there are slight differences after convert due to different implementations
|
||||
# of quant/dequant
|
||||
self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1)
|
||||
self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35)
|
||||
self.assertTrue(
|
||||
torch.max(after_quant_result - after_quant_result_fx) < 1e-1
|
||||
)
|
||||
self.assertTrue(
|
||||
compute_sqnr(after_quant_result, after_quant_result_fx) > 35
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from .quantizer import Quantizer
|
||||
from .qnnpack_quantizer import QNNPackQuantizer
|
||||
from .quantizer import OperatorConfig, Quantizer
|
||||
|
||||
__all__ = [
|
||||
"Quantizer"
|
||||
"Quantizer",
|
||||
"QNNPackQuantizer",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,95 +1,80 @@
|
|||
from __future__ import annotations
|
||||
from .quantizer import Quantizer
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Optional, Set, Dict, Callable
|
||||
import operator
|
||||
from typing import Callable, Dict, List, Optional, Set
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.ao.quantization.observer import (
|
||||
PlaceholderObserver,
|
||||
HistogramObserver,
|
||||
MinMaxObserver,
|
||||
PerChannelMinMaxObserver,
|
||||
PlaceholderObserver,
|
||||
)
|
||||
import torch
|
||||
import operator
|
||||
from torch.fx import Node
|
||||
|
||||
from .quantizer import (
|
||||
OperatorConfig,
|
||||
OperatorPatternType,
|
||||
QuantizationConfig,
|
||||
QuantizationSpec,
|
||||
Quantizer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"QNNPackQuantizer",
|
||||
"get_default_symmetric_qnnpack_operator_spec",
|
||||
"get_default_per_channel_symmetric_qnnpack_operator_spec",
|
||||
"get_default_symmetric_qnnpack_quantization_config",
|
||||
"get_default_per_channel_symmetric_qnnpack_quantization_config",
|
||||
]
|
||||
|
||||
# TODO: maybe remove torch.float32
|
||||
SUPPORTED_DTYPES = [torch.uint8, torch.int8, torch.int32, torch.float16, torch.float32]
|
||||
SUPPORTED_QSCHEMES = [
|
||||
torch.per_tensor_affine,
|
||||
torch.per_tensor_symmetric,
|
||||
torch.per_channel_affine,
|
||||
torch.per_channel_symmetric,
|
||||
torch.per_channel_affine_float_qparams,
|
||||
]
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class TensorSpec:
|
||||
dtype: torch.dtype
|
||||
is_dynamic: bool = False
|
||||
quant_min: Optional[int] = None
|
||||
quant_max: Optional[int] = None
|
||||
qscheme: Optional[torch.qscheme] = None
|
||||
ch_axis: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# check dtype is one of the supported types
|
||||
if self.dtype not in SUPPORTED_DTYPES:
|
||||
raise TypeError(f"Unsupported dtype {self.dtype}.")
|
||||
|
||||
# quant_min must be less than quant_max
|
||||
if self.quant_min is not None and self.quant_max is not None and self.quant_min > self.quant_max:
|
||||
raise ValueError(
|
||||
f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
|
||||
)
|
||||
|
||||
# check qscheme is one of the supported ones
|
||||
if self.qscheme is not None and self.qscheme not in SUPPORTED_QSCHEMES:
|
||||
raise ValueError(f"Unsupported qscheme {self.qscheme}.")
|
||||
|
||||
# ch_axis must be less than the number of channels
|
||||
# but no way to check here. Just check that it is not < 0.
|
||||
if self.ch_axis is not None and self.ch_axis < 0:
|
||||
raise ValueError("Ch_axis is < 0.")
|
||||
|
||||
|
||||
OperatorSpec = NamedTuple(
|
||||
"OperatorSpec", [("activation", TensorSpec), ("weight", TensorSpec), ("bias", TensorSpec)]
|
||||
)
|
||||
|
||||
SpecAndOperators = NamedTuple(
|
||||
"SpecAndOperators",
|
||||
[("operator_spec", OperatorSpec), ("operators", List[str])]
|
||||
)
|
||||
|
||||
|
||||
def supported_symmetric_quantized_operators() -> List[str]:
|
||||
supported_operators = ["conv2d", "linear", "add", "maxpool2d", "hardtanh", "mean", "adaptive_avgpool2d"]
|
||||
def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
|
||||
supported_operators: Dict[str, List[OperatorPatternType]] = {
|
||||
# Both conv and linear should be able to handle relu + hardtanh fusion since
|
||||
# those are clamp ops
|
||||
"conv2d": [
|
||||
[torch.nn.Conv2d, torch.nn.ReLU],
|
||||
[torch.nn.Conv2d, F.relu],
|
||||
[F.conv2d, torch.nn.ReLU],
|
||||
[F.conv2d, F.relu],
|
||||
],
|
||||
"linear": [[torch.nn.Linear], [F.linear]],
|
||||
"add": [[torch.add]],
|
||||
"maxpool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
|
||||
"hardtanh": [[torch.nn.Hardtanh], [F.hardtanh]],
|
||||
"mean": [[torch.mean]],
|
||||
"adaptive_avgpool2d": [
|
||||
[torch.nn.AdaptiveAvgPool2d],
|
||||
[F.adaptive_avg_pool2d],
|
||||
],
|
||||
}
|
||||
return copy.deepcopy(supported_operators)
|
||||
|
||||
def get_supported_symmetric_quantized_spec_and_operators() -> List[SpecAndOperators]:
|
||||
supported_spec_and_operators: List[SpecAndOperators] = []
|
||||
for operator_spec in [get_default_symmetric_qnnpack_operator_spec(), get_default_per_channel_symmetric_qnnpack_operator_spec()]:
|
||||
ops = supported_symmetric_quantized_operators()
|
||||
supported_spec_and_operators.append(SpecAndOperators(operator_spec, ops))
|
||||
return copy.deepcopy(supported_spec_and_operators)
|
||||
|
||||
def get_default_symmetric_qnnpack_operator_spec():
|
||||
act_tensor_spec = TensorSpec(
|
||||
def get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
|
||||
supported_config_and_operators: List[OperatorConfig] = []
|
||||
for quantization_config in [
|
||||
get_default_symmetric_qnnpack_quantization_config(),
|
||||
get_default_per_channel_symmetric_qnnpack_quantization_config(),
|
||||
]:
|
||||
ops = supported_symmetric_quantized_operators()
|
||||
for op_string, pattern_list in ops.items():
|
||||
supported_config_and_operators.append(
|
||||
OperatorConfig(quantization_config, pattern_list)
|
||||
)
|
||||
return copy.deepcopy(supported_config_and_operators)
|
||||
|
||||
|
||||
def get_default_symmetric_qnnpack_quantization_config():
|
||||
act_quantization_spec = QuantizationSpec(
|
||||
dtype=torch.int8,
|
||||
quant_min=-128,
|
||||
quant_max=127,
|
||||
qscheme=torch.per_tensor_affine,
|
||||
is_dynamic=False,
|
||||
)
|
||||
weight_tensor_spec = TensorSpec(
|
||||
weight_quantization_spec = QuantizationSpec(
|
||||
dtype=torch.int8,
|
||||
quant_min=-127,
|
||||
quant_max=127,
|
||||
|
|
@ -97,19 +82,22 @@ def get_default_symmetric_qnnpack_operator_spec():
|
|||
ch_axis=1,
|
||||
is_dynamic=False,
|
||||
)
|
||||
bias_tensor_spec = TensorSpec(dtype=torch.float)
|
||||
operator_spec = OperatorSpec(act_tensor_spec, weight_tensor_spec, bias_tensor_spec)
|
||||
return operator_spec
|
||||
bias_quantization_spec = QuantizationSpec(dtype=torch.float)
|
||||
quantization_config = QuantizationConfig(
|
||||
act_quantization_spec, weight_quantization_spec, bias_quantization_spec
|
||||
)
|
||||
return quantization_config
|
||||
|
||||
def get_default_per_channel_symmetric_qnnpack_operator_spec():
|
||||
act_tensor_spec = TensorSpec(
|
||||
|
||||
def get_default_per_channel_symmetric_qnnpack_quantization_config():
|
||||
act_quantization_spec = QuantizationSpec(
|
||||
dtype=torch.int8,
|
||||
quant_min=-128,
|
||||
quant_max=127,
|
||||
qscheme=torch.per_tensor_affine,
|
||||
is_dynamic=False,
|
||||
)
|
||||
weight_tensor_spec = TensorSpec(
|
||||
weight_quantization_spec = QuantizationSpec(
|
||||
dtype=torch.int8,
|
||||
quant_min=-127,
|
||||
quant_max=127,
|
||||
|
|
@ -117,27 +105,16 @@ def get_default_per_channel_symmetric_qnnpack_operator_spec():
|
|||
ch_axis=1,
|
||||
is_dynamic=False,
|
||||
)
|
||||
bias_tensor_spec = TensorSpec(dtype=torch.float)
|
||||
operator_spec = OperatorSpec(act_tensor_spec, weight_tensor_spec, bias_tensor_spec)
|
||||
return operator_spec
|
||||
bias_quantization_spec = QuantizationSpec(dtype=torch.float)
|
||||
quantization_config = QuantizationConfig(
|
||||
act_quantization_spec, weight_quantization_spec, bias_quantization_spec
|
||||
)
|
||||
return quantization_config
|
||||
|
||||
def get_supported_spec_and_operators() -> List[SpecAndOperators]:
|
||||
return get_supported_symmetric_quantized_spec_and_operators()
|
||||
|
||||
class OperatorSpecConfig:
|
||||
def get_supported_config_and_operators() -> List[OperatorConfig]:
|
||||
return get_supported_symmetric_config_and_operators()
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.global_spec: Optional[OperatorSpec] = None
|
||||
self.operator_type_specs: Dict[str, Optional[OperatorSpec]] = {}
|
||||
|
||||
def set_global(self, operator_spec: Optional[OperatorSpec]) -> OperatorSpecConfig:
|
||||
self.global_spec = operator_spec
|
||||
return self
|
||||
|
||||
def set_operator_type(self, operator_type: str, operator_spec: Optional[OperatorSpec]) -> OperatorSpecConfig:
|
||||
self.operator_type_specs[operator_type] = operator_spec
|
||||
return self
|
||||
|
||||
# TODO: add support for torch dtype in quant code base
|
||||
# this includes observers and prepare/convert code
|
||||
|
|
@ -147,61 +124,76 @@ _TORCH_DTYPE_TO_QDTYPE = {
|
|||
torch.int32: torch.qint32,
|
||||
torch.float16: torch.float16,
|
||||
}
|
||||
def _get_act_obs_or_fq_ctr(operator_spec: Optional[OperatorSpec]):
|
||||
if operator_spec is None:
|
||||
|
||||
|
||||
def _get_act_obs_or_fq_ctr(quantization_config: Optional[QuantizationConfig]):
|
||||
if quantization_config is None:
|
||||
return None
|
||||
assert operator_spec is not None
|
||||
tensor_spec: TensorSpec = operator_spec.activation
|
||||
qdtype = _TORCH_DTYPE_TO_QDTYPE[tensor_spec.dtype]
|
||||
assert tensor_spec.qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]
|
||||
if not tensor_spec.is_dynamic:
|
||||
assert quantization_config is not None
|
||||
quantization_spec: QuantizationSpec = quantization_config.activation
|
||||
qdtype = _TORCH_DTYPE_TO_QDTYPE[quantization_spec.dtype]
|
||||
assert quantization_spec.qscheme in [
|
||||
torch.per_tensor_affine,
|
||||
torch.per_tensor_symmetric,
|
||||
]
|
||||
if not quantization_spec.is_dynamic:
|
||||
return HistogramObserver.with_args(
|
||||
dtype=qdtype,
|
||||
quant_min=tensor_spec.quant_min,
|
||||
quant_max=tensor_spec.quant_max,
|
||||
quant_min=quantization_spec.quant_min,
|
||||
quant_max=quantization_spec.quant_max,
|
||||
reduce_range=False,
|
||||
eps=2**-12
|
||||
eps=2**-12,
|
||||
)
|
||||
else:
|
||||
# TODO: extend this helper function to support dynamic quantization
|
||||
raise Exception("Unsupported tensor_spec for activation: {}".format(tensor_spec))
|
||||
|
||||
def _get_weight_obs_or_fq_ctr(operator_spec: Optional[OperatorSpec]):
|
||||
if operator_spec is None:
|
||||
return None
|
||||
assert operator_spec is not None
|
||||
tensor_spec: TensorSpec = operator_spec.weight
|
||||
qdtype = _TORCH_DTYPE_TO_QDTYPE[tensor_spec.dtype]
|
||||
if tensor_spec.qscheme == torch.per_tensor_symmetric:
|
||||
return MinMaxObserver.with_args(
|
||||
qscheme=tensor_spec.qscheme,
|
||||
dtype=qdtype,
|
||||
quant_min=tensor_spec.quant_min,
|
||||
quant_max=tensor_spec.quant_max,
|
||||
eps=2**-12
|
||||
raise Exception(
|
||||
"Unsupported quantization_spec for activation: {}".format(quantization_spec)
|
||||
)
|
||||
elif tensor_spec.qscheme == torch.per_channel_symmetric:
|
||||
return PerChannelMinMaxObserver.with_args(
|
||||
qscheme=tensor_spec.qscheme,
|
||||
|
||||
|
||||
def _get_weight_obs_or_fq_ctr(quantization_config: Optional[QuantizationConfig]):
|
||||
if quantization_config is None:
|
||||
return None
|
||||
assert quantization_config is not None
|
||||
quantization_spec: QuantizationSpec = quantization_config.weight
|
||||
qdtype = _TORCH_DTYPE_TO_QDTYPE[quantization_spec.dtype]
|
||||
if quantization_spec.qscheme == torch.per_tensor_symmetric:
|
||||
return MinMaxObserver.with_args(
|
||||
qscheme=quantization_spec.qscheme,
|
||||
dtype=qdtype,
|
||||
quant_min=tensor_spec.quant_min,
|
||||
quant_max=tensor_spec.quant_max,
|
||||
eps=2**-12
|
||||
quant_min=quantization_spec.quant_min,
|
||||
quant_max=quantization_spec.quant_max,
|
||||
eps=2**-12,
|
||||
)
|
||||
elif quantization_spec.qscheme == torch.per_channel_symmetric:
|
||||
return PerChannelMinMaxObserver.with_args(
|
||||
qscheme=quantization_spec.qscheme,
|
||||
dtype=qdtype,
|
||||
quant_min=quantization_spec.quant_min,
|
||||
quant_max=quantization_spec.quant_max,
|
||||
eps=2**-12,
|
||||
)
|
||||
else:
|
||||
raise Exception("Unsupported tensor_spec for weight: {}".format(tensor_spec))
|
||||
raise Exception(
|
||||
"Unsupported quantization_spec for weight: {}".format(quantization_spec)
|
||||
)
|
||||
|
||||
def _get_bias_obs_or_fq_ctr(operator_spec: Optional[OperatorSpec]):
|
||||
if operator_spec is None:
|
||||
|
||||
def _get_bias_obs_or_fq_ctr(quantization_config: Optional[QuantizationConfig]):
|
||||
if quantization_config is None:
|
||||
return None
|
||||
assert operator_spec is not None
|
||||
tensor_spec: TensorSpec = operator_spec.bias
|
||||
assert tensor_spec.dtype == torch.float, "Only float dtype for bias is supported for bias right now"
|
||||
return PlaceholderObserver.with_args(dtype=tensor_spec.dtype)
|
||||
assert quantization_config is not None
|
||||
quantization_spec: QuantizationSpec = quantization_config.bias
|
||||
assert (
|
||||
quantization_spec.dtype == torch.float
|
||||
), "Only float dtype for bias is supported for bias right now"
|
||||
return PlaceholderObserver.with_args(dtype=quantization_spec.dtype)
|
||||
|
||||
|
||||
def _get_default_obs_or_fq_ctr():
|
||||
return PlaceholderObserver.with_args(dtype=torch.float)
|
||||
|
||||
|
||||
def _is_annotated(nodes: List[Node]):
|
||||
"""
|
||||
Given a list of nodes (that represents an operator pattern),
|
||||
|
|
@ -210,96 +202,106 @@ def _is_annotated(nodes: List[Node]):
|
|||
"""
|
||||
annotated = False
|
||||
for node in nodes:
|
||||
annotated = annotated or ("target_dtype_info" in node.meta and node.meta["target_dtype_info"].get("_annotated", False))
|
||||
annotated = annotated or (
|
||||
"target_dtype_info" in node.meta
|
||||
and node.meta["target_dtype_info"].get("_annotated", False)
|
||||
)
|
||||
return annotated
|
||||
|
||||
|
||||
class QNNPackQuantizer(Quantizer):
|
||||
supported_spec_and_operators = get_supported_spec_and_operators()
|
||||
supported_config_and_operators = get_supported_config_and_operators()
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.operator_spec_config = OperatorSpecConfig()
|
||||
self.global_config: Optional[QuantizationConfig] = None
|
||||
self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_supported_operator_specs(cls) -> List[OperatorSpec]:
|
||||
op_specs: Set[OperatorSpec] = set({})
|
||||
for spec, _ in cls.supported_spec_and_operators:
|
||||
op_specs.add(spec)
|
||||
return list(op_specs)
|
||||
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
|
||||
op_configs: Set[QuantizationConfig] = set({})
|
||||
for spec, _ in cls.supported_config_and_operators:
|
||||
op_configs.add(spec)
|
||||
return list(op_configs)
|
||||
|
||||
@classmethod
|
||||
def get_supported_operator_for_operator_spec(cls, operator_spec: Optional[OperatorSpec]) -> List[str]:
|
||||
if operator_spec is None:
|
||||
def get_supported_operator_for_quantization_config(
|
||||
cls, quantization_config: Optional[QuantizationConfig]
|
||||
) -> List[OperatorPatternType]:
|
||||
if quantization_config is None:
|
||||
all_ops = []
|
||||
for _, ops in cls.supported_spec_and_operators:
|
||||
for _, ops in cls.supported_config_and_operators:
|
||||
all_ops.extend(ops)
|
||||
return all_ops
|
||||
|
||||
for spec, ops in cls.supported_spec_and_operators:
|
||||
for config, ops in cls.supported_config_and_operators:
|
||||
# note: this assumes each entry in cls.supported_spec_and_operators
|
||||
# corresponds to one spec, e.g. we don't have
|
||||
# [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
|
||||
# where the first and second entry have the same spec but did not
|
||||
# merge the op list
|
||||
if spec == operator_spec:
|
||||
if config == quantization_config:
|
||||
return ops
|
||||
return []
|
||||
|
||||
def set_global(self, operator_spec: Optional[OperatorSpec]) -> QNNPackQuantizer:
|
||||
self.operator_spec_config.set_global(operator_spec)
|
||||
def set_global(
|
||||
self, quantization_config: Optional[QuantizationConfig]
|
||||
) -> QNNPackQuantizer:
|
||||
self.global_config = quantization_config
|
||||
return self
|
||||
|
||||
def set_spec_for_operator_type(
|
||||
self, operator_type: str, operator_spec: Optional[OperatorSpec]
|
||||
def set_config_for_operator_type(
|
||||
self, operator_type: str, quantization_config: Optional[QuantizationConfig]
|
||||
) -> QNNPackQuantizer:
|
||||
self.operator_spec_config.set_operator_type(operator_type, operator_spec)
|
||||
self.operator_type_config[operator_type] = quantization_config
|
||||
return self
|
||||
|
||||
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
||||
""" just handling global spec for now
|
||||
"""
|
||||
global_spec = self.operator_spec_config.global_spec
|
||||
ops = self.get_supported_operator_for_operator_spec(global_spec)
|
||||
"""just handling global spec for now"""
|
||||
global_config = self.global_config
|
||||
ops = self.get_supported_operator_for_quantization_config(global_config)
|
||||
# annotate the nodes from last to first since the matching is in the reversed order
|
||||
# and fusion operator patterns (conv - relu) can get matched before single operator pattern (conv)
|
||||
# and we will mark the matched node with "_annotated" so fusion operator pattern
|
||||
# can take precedence over single operator pattern in this way
|
||||
for node in reversed(model.graph.nodes):
|
||||
for op in ops:
|
||||
if op == "conv2d":
|
||||
self._annotate_conv2d_relu(node, global_spec)
|
||||
self._annotate_conv2d(node, global_spec)
|
||||
elif op == "linear":
|
||||
self._annotate_linear(node, global_spec)
|
||||
elif op == "maxpool2d":
|
||||
self._annotate_maxpool2d(node, global_spec)
|
||||
elif op == "add":
|
||||
self._annotate_add_relu(node, global_spec)
|
||||
self._annotate_add(node, global_spec)
|
||||
elif op == "hardtanh":
|
||||
self._annotate_hardtanh(node, global_spec)
|
||||
elif op == "mean":
|
||||
self._annotate_mean(node, global_spec)
|
||||
elif op == "adaptive_avgpool2d":
|
||||
self._annotate_adaptive_avg_pool2d(node, global_spec)
|
||||
# one improvement is to register node annotators for each
|
||||
# supported op type.
|
||||
self._annotate_conv2d_relu(node, global_config)
|
||||
self._annotate_conv2d(node, global_config)
|
||||
self._annotate_linear(node, global_config)
|
||||
self._annotate_maxpool2d(node, global_config)
|
||||
self._annotate_add_relu(node, global_config)
|
||||
self._annotate_add(node, global_config)
|
||||
self._annotate_hardtanh(node, global_config)
|
||||
self._annotate_mean(node, global_config)
|
||||
self._annotate_adaptive_avg_pool2d(node, global_config)
|
||||
|
||||
return model
|
||||
|
||||
def _annotate_conv2d_relu(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
|
||||
if node.op != "call_function" or node.target not in [torch.ops.aten.relu_.default, torch.ops.aten.relu.default]:
|
||||
def _annotate_conv2d_relu(
|
||||
self, node: Node, quantization_config: Optional[QuantizationConfig]
|
||||
) -> None:
|
||||
if node.op != "call_function" or node.target not in [
|
||||
torch.ops.aten.relu_.default,
|
||||
torch.ops.aten.relu.default,
|
||||
]:
|
||||
return
|
||||
relu_node = node
|
||||
conv_node = relu_node.args[0]
|
||||
assert isinstance(conv_node, Node)
|
||||
if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
|
||||
if (
|
||||
conv_node.op != "call_function"
|
||||
or conv_node.target != torch.ops.aten.convolution.default
|
||||
):
|
||||
return
|
||||
if _is_annotated([relu_node, conv_node]):
|
||||
return
|
||||
|
||||
conv_node.meta["target_dtype_info"] = {
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"weight_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(operator_spec),
|
||||
"bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(operator_spec),
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"weight_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(quantization_config),
|
||||
"bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(quantization_config),
|
||||
"output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
|
||||
# TODO: validation of weight_index must be set if weight_obs_or_fq_ctr is set
|
||||
"weight_index": 1,
|
||||
|
|
@ -309,22 +311,27 @@ class QNNPackQuantizer(Quantizer):
|
|||
}
|
||||
relu_node.meta["target_dtype_info"] = {
|
||||
"input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"_annotated": True,
|
||||
}
|
||||
|
||||
def _annotate_conv2d(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
|
||||
def _annotate_conv2d(
|
||||
self, node: Node, quantization_config: Optional[QuantizationConfig]
|
||||
) -> None:
|
||||
conv_node = node
|
||||
if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
|
||||
if (
|
||||
conv_node.op != "call_function"
|
||||
or conv_node.target != torch.ops.aten.convolution.default
|
||||
):
|
||||
return
|
||||
# skip annotation if it is already annotated
|
||||
if _is_annotated([conv_node]):
|
||||
return
|
||||
conv_node.meta["target_dtype_info"] = {
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"weight_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(operator_spec),
|
||||
"bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(operator_spec),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"weight_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(quantization_config),
|
||||
"bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(quantization_config),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
# TODO: validation of weight_index must be set if weight_obs_or_fq_ctr is set
|
||||
"weight_index": 1,
|
||||
# TODO: validation of bias_index must be set if bias_obs_or_fq_ctr is set
|
||||
|
|
@ -332,13 +339,21 @@ class QNNPackQuantizer(Quantizer):
|
|||
"_annotated": True,
|
||||
}
|
||||
|
||||
def _annotate_linear(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
|
||||
def _annotate_linear(
|
||||
self, node: Node, quantization_config: Optional[QuantizationConfig]
|
||||
) -> None:
|
||||
addmm_node = node
|
||||
if addmm_node.op != "call_function" or addmm_node.target != torch.ops.aten.addmm.default:
|
||||
if (
|
||||
addmm_node.op != "call_function"
|
||||
or addmm_node.target != torch.ops.aten.addmm.default
|
||||
):
|
||||
return
|
||||
view_node = addmm_node.args[1]
|
||||
assert isinstance(view_node, Node)
|
||||
if view_node.op != "call_function" or view_node.target != torch.ops.aten.view.default:
|
||||
if (
|
||||
view_node.op != "call_function"
|
||||
or view_node.target != torch.ops.aten.view.default
|
||||
):
|
||||
return
|
||||
t_node = addmm_node.args[2]
|
||||
assert isinstance(t_node, Node)
|
||||
|
|
@ -349,106 +364,154 @@ class QNNPackQuantizer(Quantizer):
|
|||
|
||||
# bias and output act
|
||||
addmm_node.meta["target_dtype_info"] = {
|
||||
"bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(operator_spec),
|
||||
"bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(quantization_config),
|
||||
"input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"bias_index": 0,
|
||||
"_annotated": True,
|
||||
}
|
||||
# input act
|
||||
view_node.meta["target_dtype_info"] = {
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
|
||||
"_annotated": True,
|
||||
}
|
||||
# weight
|
||||
t_node.meta["target_dtype_info"] = {
|
||||
"input_act_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(operator_spec),
|
||||
"input_act_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(quantization_config),
|
||||
"output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
|
||||
"_annotated": True,
|
||||
}
|
||||
|
||||
def _annotate_maxpool2d(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
|
||||
if node.op != "call_function" or node.target != operator.getitem or node.args[1] != 0:
|
||||
def _annotate_maxpool2d(
|
||||
self, node: Node, quantization_config: Optional[QuantizationConfig]
|
||||
) -> None:
|
||||
if (
|
||||
node.op != "call_function"
|
||||
or node.target != operator.getitem
|
||||
or node.args[1] != 0
|
||||
):
|
||||
return
|
||||
getitem_node = node
|
||||
maxpool_node = getitem_node.args[0]
|
||||
assert isinstance(maxpool_node, Node)
|
||||
if maxpool_node.op != "call_function" or maxpool_node.target != torch.ops.aten.max_pool2d_with_indices.default:
|
||||
if (
|
||||
maxpool_node.op != "call_function"
|
||||
or maxpool_node.target != torch.ops.aten.max_pool2d_with_indices.default
|
||||
):
|
||||
return
|
||||
if _is_annotated([getitem_node, maxpool_node]):
|
||||
return
|
||||
|
||||
maxpool_node.meta["target_dtype_info"] = {
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
|
||||
"_annotated": True,
|
||||
}
|
||||
getitem_node.meta["target_dtype_info"] = {
|
||||
"input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"input_output_share_observers": True,
|
||||
"_annotated": True,
|
||||
}
|
||||
|
||||
def _annotate_input_out_obs_sharing_op(self, op: Callable, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
|
||||
def _annotate_input_out_obs_sharing_op(
|
||||
self,
|
||||
op: Callable,
|
||||
node: Node,
|
||||
quantization_config: Optional[QuantizationConfig],
|
||||
) -> None:
|
||||
io_obs_sharing_node = node
|
||||
if io_obs_sharing_node.op != "call_function" or io_obs_sharing_node.target != op:
|
||||
if (
|
||||
io_obs_sharing_node.op != "call_function"
|
||||
or io_obs_sharing_node.target != op
|
||||
):
|
||||
return
|
||||
if _is_annotated([io_obs_sharing_node]):
|
||||
return
|
||||
|
||||
io_obs_sharing_node.meta["target_dtype_info"] = {
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"input_output_share_observers": True,
|
||||
"_annotated": True,
|
||||
}
|
||||
|
||||
def _annotate_hardtanh(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
|
||||
self._annotate_input_out_obs_sharing_op(torch.ops.aten.hardtanh.default, node, operator_spec)
|
||||
def _annotate_hardtanh(
|
||||
self, node: Node, quantization_config: Optional[QuantizationConfig]
|
||||
) -> None:
|
||||
self._annotate_input_out_obs_sharing_op(
|
||||
torch.ops.aten.hardtanh.default, node, quantization_config
|
||||
)
|
||||
|
||||
def _annotate_mean(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
|
||||
self._annotate_input_out_obs_sharing_op(torch.ops.aten.mean.default, node, operator_spec)
|
||||
self._annotate_input_out_obs_sharing_op(torch.ops.aten.mean.dim, node, operator_spec)
|
||||
def _annotate_mean(
|
||||
self, node: Node, quantization_config: Optional[QuantizationConfig]
|
||||
) -> None:
|
||||
self._annotate_input_out_obs_sharing_op(
|
||||
torch.ops.aten.mean.default, node, quantization_config
|
||||
)
|
||||
self._annotate_input_out_obs_sharing_op(
|
||||
torch.ops.aten.mean.dim, node, quantization_config
|
||||
)
|
||||
|
||||
def _annotate_adaptive_avg_pool2d(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
|
||||
self._annotate_input_out_obs_sharing_op(torch.ops.aten.adaptive_avg_pool2d.default, node, operator_spec)
|
||||
def _annotate_adaptive_avg_pool2d(
|
||||
self, node: Node, quantization_config: Optional[QuantizationConfig]
|
||||
) -> None:
|
||||
self._annotate_input_out_obs_sharing_op(
|
||||
torch.ops.aten.adaptive_avg_pool2d.default, node, quantization_config
|
||||
)
|
||||
|
||||
def _annotate_add_relu(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
|
||||
if node.op != "call_function" or node.target not in [torch.ops.aten.relu_.default, torch.ops.aten.relu.default]:
|
||||
def _annotate_add_relu(
|
||||
self, node: Node, quantization_config: Optional[QuantizationConfig]
|
||||
) -> None:
|
||||
if node.op != "call_function" or node.target not in [
|
||||
torch.ops.aten.relu_.default,
|
||||
torch.ops.aten.relu.default,
|
||||
]:
|
||||
return
|
||||
relu_node = node
|
||||
add_node = relu_node.args[0]
|
||||
assert isinstance(add_node, Node)
|
||||
if add_node.op != "call_function" or add_node.target not in [torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor]:
|
||||
if add_node.op != "call_function" or add_node.target not in [
|
||||
torch.ops.aten.add.Tensor,
|
||||
torch.ops.aten.add_.Tensor,
|
||||
]:
|
||||
return
|
||||
if _is_annotated([relu_node, add_node]):
|
||||
return
|
||||
|
||||
add_node.meta["target_dtype_info"] = {
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
|
||||
"_annotated": True,
|
||||
}
|
||||
relu_node.meta["target_dtype_info"] = {
|
||||
"input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"_annotated": True,
|
||||
}
|
||||
|
||||
def _annotate_add(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
|
||||
def _annotate_add(
|
||||
self, node: Node, quantization_config: Optional[QuantizationConfig]
|
||||
) -> None:
|
||||
add_node = node
|
||||
if add_node.op != "call_function" or add_node.target not in [torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor]:
|
||||
if add_node.op != "call_function" or add_node.target not in [
|
||||
torch.ops.aten.add.Tensor,
|
||||
torch.ops.aten.add_.Tensor,
|
||||
]:
|
||||
return
|
||||
if _is_annotated([add_node]):
|
||||
return
|
||||
|
||||
add_node.meta["target_dtype_info"] = {
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
|
||||
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
|
||||
"_annotated": True,
|
||||
}
|
||||
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_supported_operators(cls) -> List[OperatorConfig]:
|
||||
return cls.supported_config_and_operators
|
||||
|
|
|
|||
|
|
@ -1,11 +1,90 @@
|
|||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"Quantizer",
|
||||
]
|
||||
|
||||
# TODO: maybe remove torch.float32
|
||||
SUPPORTED_DTYPES = [torch.uint8, torch.int8, torch.int32, torch.float16, torch.float32]
|
||||
SUPPORTED_QSCHEMES = [
|
||||
torch.per_tensor_affine,
|
||||
torch.per_tensor_symmetric,
|
||||
torch.per_channel_affine,
|
||||
torch.per_channel_symmetric,
|
||||
torch.per_channel_affine_float_qparams,
|
||||
]
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class QuantizationSpec:
|
||||
dtype: torch.dtype
|
||||
is_dynamic: bool = False
|
||||
quant_min: Optional[int] = None
|
||||
quant_max: Optional[int] = None
|
||||
qscheme: Optional[torch.qscheme] = None
|
||||
ch_axis: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# check dtype is one of the supported types
|
||||
if self.dtype not in SUPPORTED_DTYPES:
|
||||
raise TypeError(f"Unsupported dtype {self.dtype}.")
|
||||
|
||||
# quant_min must be less than quant_max
|
||||
if (
|
||||
self.quant_min is not None
|
||||
and self.quant_max is not None
|
||||
and self.quant_min > self.quant_max
|
||||
):
|
||||
raise ValueError(
|
||||
f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
|
||||
)
|
||||
|
||||
# check qscheme is one of the supported ones
|
||||
if self.qscheme is not None and self.qscheme not in SUPPORTED_QSCHEMES:
|
||||
raise ValueError(f"Unsupported qscheme {self.qscheme}.")
|
||||
|
||||
# ch_axis must be less than the number of channels
|
||||
# but no way to check here. Just check that it is not < 0.
|
||||
if self.ch_axis is not None and self.ch_axis < 0:
|
||||
raise ValueError("Ch_axis is < 0.")
|
||||
|
||||
|
||||
# In the absence of better name, just winging it with QuantizationConfig
|
||||
QuantizationConfig = NamedTuple(
|
||||
"QuantizationConfig",
|
||||
[
|
||||
("activation", QuantizationSpec),
|
||||
("weight", QuantizationSpec),
|
||||
("bias", QuantizationSpec),
|
||||
],
|
||||
)
|
||||
|
||||
OperatorPatternType = List[Callable]
|
||||
|
||||
OperatorConfig = NamedTuple(
|
||||
"OperatorConfig",
|
||||
# fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
|
||||
# Basically we are mapping a quantization config to some list of patterns.
|
||||
# a pattern is defined as a list of nn module, function or builtin function names
|
||||
# e.g. [nn.Conv2d, torch.relu, torch.add]
|
||||
# We have not resolved whether fusion can be considered internal details of the
|
||||
# quantizer hence it does not need communication to user.
|
||||
# Note this pattern is not really informative since it does not really
|
||||
# tell us the graph structure resulting from the list of ops.
|
||||
[
|
||||
("config", QuantizationConfig),
|
||||
(
|
||||
"operators",
|
||||
List[OperatorPatternType],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class Quantizer(ABC):
|
||||
|
||||
# annotate nodes in the graph with observer or fake quant constructors
|
||||
|
|
@ -18,3 +97,10 @@ class Quantizer(ABC):
|
|||
@abstractmethod
|
||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||
pass
|
||||
|
||||
# annotate nodes in the graph with observer or fake quant constructors
|
||||
# to convey the desired way of quantization
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_supported_operators(cls) -> List[OperatorConfig]:
|
||||
pass