[PT2E][Quantization] Refactor Quantizer and QNNPACKQuantizer (#99063)

This diff renames the quantization spec/config and operator config data structures and
moves them to the base quantizer.
The base quantizer API now has get_supported_operators, which returns the list of
operator patterns that a quantizer quantizes.
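
For reference, here is a condensed sketch of the refactored data structures and base Quantizer API as they appear in this diff (validation logic omitted; the NamedTuples are shown in class syntax for brevity):

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, List, NamedTuple, Optional

import torch


@dataclass(eq=True, frozen=True)
class QuantizationSpec:  # formerly TensorSpec
    dtype: torch.dtype
    is_dynamic: bool = False
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None


class QuantizationConfig(NamedTuple):  # formerly OperatorSpec
    activation: QuantizationSpec
    weight: QuantizationSpec
    bias: QuantizationSpec


# A pattern is a list of modules/functions, e.g. [torch.nn.Conv2d, torch.nn.ReLU].
OperatorPatternType = List[Callable]


class OperatorConfig(NamedTuple):
    config: QuantizationConfig
    operators: List[OperatorPatternType]


class Quantizer(ABC):
    # annotate nodes in the graph with observer or fake quant constructors
    # to convey the desired way of quantization
    @abstractmethod
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        pass

    # validate the annotated graph module
    @abstractmethod
    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    @classmethod
    @abstractmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        pass
```

QNNPackQuantizer implements get_supported_operators by returning its class-level supported_config_and_operators list.
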
There are two choices being debated for how to convey to the user what a particular
quantizer will quantize.

1. Modules. We only convey which nn.Modules will be quantized. Of course that
does not mean that equivalent functional variants won't be quantized; however,
for simplicity we just use nn.Module. If certain ops are quantized in a fused
manner, that is considered an internal detail. Pros and cons of this approach
(a minimal sketch follows after the list):
Pros:
  - Simple. Only nn.Modules are listed.
  - The user does not have to see fusion patterns.
Cons:
  - Perhaps confusing, because it is not clear whether supported = nn.Conv2d also
    means that the quantizer supports functional.conv2d.
  - Hiding fusion patterns means the user has no say in not fusing. That is, if
    conv2d + relu is fused and the user configures quantization of only conv, the
    quantizer will also quantize the following relu as if conv2d + relu were fused.
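
For illustration only, a module-only listing could look like the minimal sketch below; no such list is added in this diff and its contents are hypothetical:

```python
import torch

# Hypothetical module-only listing: only nn.Module types are exposed;
# functional variants and fusion patterns remain internal details.
SUPPORTED_MODULES = [
    torch.nn.Conv2d,
    torch.nn.Linear,
    torch.nn.MaxPool2d,
    torch.nn.AdaptiveAvgPool2d,
]
```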

2. Patterns. Be explicit about what is supported and enumerate all possible
combinations (a sketch follows after the pros/cons below).
Pros:
  - It is very clear what the quantizer will do; no surprises.
Cons:
  - It is not simple to parse.
  - It can be argued that fusion is an internal detail of the quantizer, so some
    quantizer implementations may choose to expose fusion patterns, while others
    may not and may not even provide any configurability.
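
As a concrete illustration of the pattern approach, the pattern listing that backs QNNPackQuantizer in this diff looks roughly like this (abridged from supported_symmetric_quantized_operators()):

```python
import torch
import torch.nn.functional as F

# Each operator name maps to the list of module/functional patterns it covers.
supported_operators = {
    "conv2d": [
        [torch.nn.Conv2d, torch.nn.ReLU],
        [torch.nn.Conv2d, F.relu],
        [F.conv2d, torch.nn.ReLU],
        [F.conv2d, F.relu],
    ],
    "linear": [[torch.nn.Linear], [F.linear]],
    "add": [[torch.add]],
    "maxpool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
}
```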

One option is to move get_supported_operators/modules out of the base quantizer and
let each quantizer define its own way of communicating what is supported. The issue
with this is that when we want to "compose" multiple quantizers, there is no way
for the user to define the order of composition if the user does not know what each
quantizer supports. For example, quantizer A may quantize conv + relu while B quantizes
only conv, but B's implementation is fast. In that case you may compose (B, A) such
that B quantizes conv and A quantizes relu. Not knowing what A and B support makes
such composition harder.
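
For illustration only, a hypothetical sketch of such composition; compose(), QuantizerA, and QuantizerB are placeholders and no composition API is added in this diff:

```python
from typing import List

import torch


class QuantizerA:  # placeholder: supports conv2d + relu (fused)
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        return model


class QuantizerB:  # placeholder: supports conv2d only, with a faster implementation
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        return model


def compose(quantizers: List, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Run each quantizer's annotate() pass in the given order. Under the
    # "_annotated" convention used in this diff, earlier quantizers would take
    # precedence because later ones skip already-annotated nodes.
    for q in quantizers:
        model = q.annotate(model)
    return model


# Composing (B, A): B annotates conv first, A then picks up the remaining relu.
# Choosing this order requires knowing what A and B each support.
```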

Differential Revision: [D44895547](https://our.internmc.facebook.com/intern/diff/D44895547/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D44895547/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99063
Approved by: https://github.com/jerryzh168
Kimish Patel 2023-04-14 19:00:10 -07:00 committed by PyTorch MergeBot
parent 888c65b6a4
commit 31f311a816
4 changed files with 570 additions and 303 deletions

View File

@ -1,39 +1,46 @@
# Owner(s): ["oncall: quantization"]
import copy
import itertools
from typing import List
import torch
import torch.nn as nn
import torch._dynamo as torchdynamo
import torch.nn as nn
from torch._inductor.compile_fx import compile_fx
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization import get_default_qconfig, observer, QConfigMapping
from torch.ao.quantization._pt2e.quantizer import (
OperatorConfig,
QNNPackQuantizer,
Quantizer,
)
from torch.ao.quantization._quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_pt2e_quantizer,
)
from torch.ao.quantization.backend_config import get_qnnpack_backend_config
from torch.ao.quantization.backend_config._qnnpack_pt2e import (
get_qnnpack_pt2e_backend_config,
)
from torch.ao.quantization.backend_config._x86_inductor_pt2e import (
get_x86_inductor_pt2e_backend_config,
)
from torch.ao.quantization.backend_config.x86 import get_x86_backend_config
from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig
from torch.ao.quantization.quantize_fx import (
convert_fx,
convert_to_reference_fx,
prepare_fx,
)
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
skip_if_no_torchvision,
skipIfNoQNNPACK,
skipIfNoX86,
)
from torch.testing._internal.common_quantization import NodeSpec as ns
from torch.testing._internal.common_quantized import (
override_quantized_engine,
)
from torch.ao.quantization import (
get_default_qconfig,
QConfigMapping,
observer,
)
from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig
from torch.ao.quantization.backend_config import (
get_qnnpack_backend_config,
)
from torch.ao.quantization.backend_config._qnnpack_pt2e import get_qnnpack_pt2e_backend_config
from torch.ao.quantization.backend_config._x86_inductor_pt2e import get_x86_inductor_pt2e_backend_config
from torch.ao.quantization.backend_config.x86 import get_x86_backend_config
from torch.ao.quantization.quantize_fx import prepare_fx, convert_to_reference_fx, convert_fx
from torch.ao.quantization._pt2e.quantizer import Quantizer
from torch.ao.quantization._pt2e.quantizer import QNNPackQuantizer
from torch.ao.quantization._quantize_pt2e import prepare_pt2e, convert_pt2e, prepare_pt2e_quantizer
from torch.ao.ns.fx.utils import (
compute_sqnr,
)
import copy
import itertools
from torch._inductor.compile_fx import compile_fx
from torch.testing._internal.common_quantized import override_quantized_engine
@skipIfNoQNNPACK
@ -62,8 +69,9 @@ class TestQuantizePT2E(QuantizationTestCase):
)
qconfig = get_default_qconfig("qnnpack")
qconfig_mapping = QConfigMapping().set_global(qconfig) \
.set_module_name("conv2", None)
qconfig_mapping = (
QConfigMapping().set_global(qconfig).set_module_name("conv2", None)
)
backend_config = get_qnnpack_pt2e_backend_config()
m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
m(*example_inputs)
@ -74,7 +82,9 @@ class TestQuantizePT2E(QuantizationTestCase):
node_occurrence = {
# two for input of the first conv, one for output for the first conv
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor
): 3,
}
node_list = [
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
@ -84,7 +94,10 @@ class TestQuantizePT2E(QuantizationTestCase):
ns.call_function(torch.ops.aten.convolution.default),
]
self.checkGraphModuleNodes(
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
m,
expected_node_list=node_list,
expected_node_occurrence=node_occurrence,
)
def test_qconfig_module_type(self):
class M(torch.nn.Module):
@ -122,7 +135,9 @@ class TestQuantizePT2E(QuantizationTestCase):
node_occurrence = {
# two for input and weight of the conv, one for output for the conv
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor
): 3,
}
node_list = [
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
@ -145,17 +160,28 @@ class TestQuantizePT2E(QuantizationTestCase):
class BackendAQuantizer(Quantizer):
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
_DEFAULT_TARGET_DTYPE_INFO = {
"input_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(dtype=torch.float),
"output_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(dtype=torch.float),
"input_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(
dtype=torch.float
),
"output_act_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(
dtype=torch.float
),
}
for node in model.graph.nodes:
node.meta["target_dtype_info"] = copy.deepcopy(_DEFAULT_TARGET_DTYPE_INFO)
node.meta["target_dtype_info"] = copy.deepcopy(
_DEFAULT_TARGET_DTYPE_INFO
)
for node in model.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.convolution.default:
if (
node.op == "call_function"
and node.target == torch.ops.aten.convolution.default
):
node.meta["target_dtype_info"] = {
"input_act_obs_or_fq_ctr": observer.default_observer,
"weight_obs_or_fq_ctr": observer.default_weight_observer,
"bias_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(dtype=torch.float),
"bias_obs_or_fq_ctr": observer.PlaceholderObserver.with_args(
dtype=torch.float
),
"output_act_obs_or_fq_ctr": observer.default_observer,
"weight_index": 1,
"bias_index": 2,
@ -164,6 +190,10 @@ class TestQuantizePT2E(QuantizationTestCase):
def validate(self, model: torch.fx.GraphModule) -> None:
pass
@classmethod
def get_supported_operators(cls) -> List[OperatorConfig]:
pass
m = M().eval()
example_inputs = (torch.randn(1, 3, 5, 5),)
@ -189,7 +219,8 @@ class TestQuantizePT2E(QuantizationTestCase):
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
]
self.checkGraphModuleNodes(
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
)
def test_qnnpack_quantizer_conv(self):
class M(torch.nn.Module):
@ -201,9 +232,12 @@ class TestQuantizePT2E(QuantizationTestCase):
return self.conv(x)
import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
quantizer = QNNPackQuantizer()
operator_spec = qq.get_default_per_channel_symmetric_qnnpack_operator_spec()
quantizer.set_global(operator_spec)
operator_config = (
qq.get_default_per_channel_symmetric_qnnpack_quantization_config()
)
quantizer.set_global(operator_config)
m = M().eval()
example_inputs = (torch.randn(1, 3, 5, 5),)
@ -232,7 +266,8 @@ class TestQuantizePT2E(QuantizationTestCase):
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
]
self.checkGraphModuleNodes(
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
)
def test_qnnpack_quantizer_obs_sharing_ops(self):
class M(torch.nn.Module):
@ -250,9 +285,12 @@ class TestQuantizePT2E(QuantizationTestCase):
return x
import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
quantizer = QNNPackQuantizer()
operator_spec = qq.get_default_per_channel_symmetric_qnnpack_operator_spec()
quantizer.set_global(operator_spec)
operator_config = (
qq.get_default_per_channel_symmetric_qnnpack_quantization_config()
)
quantizer.set_global(operator_config)
m = M().eval()
example_inputs = (torch.randn(1, 3, 5, 5),)
@ -279,11 +317,9 @@ class TestQuantizePT2E(QuantizationTestCase):
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel),
ns.call_function(torch.ops.aten.convolution.default),
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
ns.call_function(torch.ops.aten.mean.dim),
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
ns.call_function(torch.ops.aten.hardtanh.default),
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
@ -293,7 +329,8 @@ class TestQuantizePT2E(QuantizationTestCase):
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
]
self.checkGraphModuleNodes(
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
)
def test_rearrange_weight_observer_for_decomposed_linear(self):
"""
@ -305,6 +342,7 @@ class TestQuantizePT2E(QuantizationTestCase):
weight - observer - t \
input - observer - addmm/mm
"""
class M(torch.nn.Module):
def __init__(self, with_bias, use_relu):
super().__init__()
@ -331,7 +369,7 @@ class TestQuantizePT2E(QuantizationTestCase):
tracing_mode="real",
)
qconfig = get_default_qconfig('qnnpack')
qconfig = get_default_qconfig("qnnpack")
qconfig_mapping = QConfigMapping().set_global(qconfig)
backend_config = get_qnnpack_pt2e_backend_config()
m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
@ -339,12 +377,19 @@ class TestQuantizePT2E(QuantizationTestCase):
# 1. Check graph nodes:
# - args[0] of t should be the weight observer
# - args[-1] of addmm/mm should be t
error_msg = 'Weight observer is not correctly rearranged for decomposed linear'
error_msg = (
"Weight observer is not correctly rearranged for decomposed linear"
)
for node in m.graph.nodes:
if node.target == torch.ops.aten.t.default:
target = node.args[0].target
self.assertTrue(isinstance(getattr(m, target), observer.ObserverBase), error_msg)
elif node.target in (torch.ops.aten.addmm.default, torch.ops.aten.mm.default):
self.assertTrue(
isinstance(getattr(m, target), observer.ObserverBase), error_msg
)
elif node.target in (
torch.ops.aten.addmm.default,
torch.ops.aten.mm.default,
):
target = node.args[-1].target
self.assertTrue(target == torch.ops.aten.t.default, error_msg)
@ -380,7 +425,9 @@ class TestQuantizePT2E(QuantizationTestCase):
node_occurrence = {
ns.call_function(torch.ops.aten.convolution.default): 1,
ns.call_function(torch.ops.aten._native_batch_norm_legit_no_training.default): 1,
ns.call_function(
torch.ops.aten._native_batch_norm_legit_no_training.default
): 1,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
@ -394,10 +441,13 @@ class TestQuantizePT2E(QuantizationTestCase):
# make sure bn is fused into conv
node_occurrence = {
ns.call_function(torch.ops.aten.convolution.default): 1,
ns.call_function(torch.ops.aten._native_batch_norm_legit_no_training.default): 0,
ns.call_function(
torch.ops.aten._native_batch_norm_legit_no_training.default
): 0,
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
@skipIfNoQNNPACK
class TestQuantizePT2EX86Inductor(QuantizationTestCase):
@skipIfNoX86
@ -417,7 +467,9 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase):
inplace_relu_list = [True, False]
with override_quantized_engine("x86"):
with torch.no_grad():
for use_relu, inplace_relu in itertools.product(use_relu_list, inplace_relu_list):
for use_relu, inplace_relu in itertools.product(
use_relu_list, inplace_relu_list
):
m = M(use_relu=use_relu, inplace_relu=inplace_relu).eval()
example_inputs = (torch.randn(2, 3, 4, 4),)
# program capture
@ -433,7 +485,9 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase):
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
backend_config = get_x86_inductor_pt2e_backend_config()
prepare_module = prepare_pt2e(export_module, qconfig_mapping, example_inputs, backend_config)
prepare_module = prepare_pt2e(
export_module, qconfig_mapping, example_inputs, backend_config
)
prepare_module(*example_inputs)
convert_module = convert_pt2e(prepare_module)
convert_module(*example_inputs)
@ -441,38 +495,75 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase):
# Fake quant should only be inserted at start and end
node_occurrence = {
# one for input and weight of the conv, one for output for the conv
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2,
ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2,
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor
): 2,
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_channel
): 1,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel
): 1,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor
): 2,
}
if use_relu:
node_list = [
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor
),
ns.call_function(torch.ops.aten.convolution.default),
ns.call_function(torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default),
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
ns.call_function(
torch.ops.aten.relu_.default
if inplace_relu
else torch.ops.aten.relu.default
),
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor
),
]
else:
node_list = [
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor
),
ns.call_function(torch.ops.aten.convolution.default),
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor),
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor
),
]
self.checkGraphModuleNodes(convert_module,
expected_node_occurrence=node_occurrence,
expected_node_list=node_list)
self.checkGraphModuleNodes(
convert_module,
expected_node_occurrence=node_occurrence,
expected_node_list=node_list,
)
# Step1: Ref result in 1.X fx path
backend_config_1_x = get_x86_backend_config()
m_copy = copy.deepcopy(m)
m_prepare_fx = prepare_fx(m_copy, qconfig_mapping, example_inputs, backend_config=backend_config_1_x)
m_prepare_fx = prepare_fx(
m_copy,
qconfig_mapping,
example_inputs,
backend_config=backend_config_1_x,
)
after_prepare_result_fx = m_prepare_fx(*example_inputs)
m_convert_fx = convert_fx(m_prepare_fx, backend_config=backend_config_1_x)
m_convert_fx = convert_fx(
m_prepare_fx, backend_config=backend_config_1_x
)
ref_result = m_convert_fx(*example_inputs)
# Step2: Start to lowering into Inductor
@ -483,11 +574,13 @@ class TestQuantizePT2EX86Inductor(QuantizationTestCase):
inductor_res = run(*example_inputs)
self.assertEqual(ref_result, inductor_res, atol=5e-2, rtol=5e-2)
class TestQuantizePT2EModels(QuantizationTestCase):
@skip_if_no_torchvision
@skipIfNoQNNPACK
def test_resnet18(self):
import torchvision
with override_quantized_engine("qnnpack"):
example_inputs = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()
@ -510,7 +603,9 @@ class TestQuantizePT2EModels(QuantizationTestCase):
# checking that we inserted observers correctly for maxpool operator (input and
# output share observer instance)
self.assertEqual(id(m.activation_post_process_3), id(m.activation_post_process_2))
self.assertEqual(
id(m.activation_post_process_3), id(m.activation_post_process_2)
)
after_prepare_result = m(*example_inputs)
m = convert_pt2e(m)
@ -518,7 +613,9 @@ class TestQuantizePT2EModels(QuantizationTestCase):
# comparing with existing fx graph mode quantization reference flow
backend_config = get_qnnpack_backend_config()
m_fx = prepare_fx(m_copy, qconfig_mapping, example_inputs, backend_config=backend_config)
m_fx = prepare_fx(
m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
)
after_prepare_result_fx = m_fx(*example_inputs)
m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
@ -526,17 +623,24 @@ class TestQuantizePT2EModels(QuantizationTestCase):
# the result matches exactly after prepare
self.assertEqual(after_prepare_result, after_prepare_result_fx)
self.assertEqual(compute_sqnr(after_prepare_result, after_prepare_result_fx), torch.tensor(float("inf")))
self.assertEqual(
compute_sqnr(after_prepare_result, after_prepare_result_fx),
torch.tensor(float("inf")),
)
# there are slight differences after convert due to different implementations
# of quant/dequant
self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1)
self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35)
self.assertTrue(
torch.max(after_quant_result - after_quant_result_fx) < 1e-1
)
self.assertTrue(
compute_sqnr(after_quant_result, after_quant_result_fx) > 35
)
@skip_if_no_torchvision
@skipIfNoQNNPACK
def test_resnet18_with_quantizer_api(self):
import torchvision
with override_quantized_engine("qnnpack"):
example_inputs = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()
@ -551,13 +655,18 @@ class TestQuantizePT2EModels(QuantizationTestCase):
before_fusion_result = m(*example_inputs)
import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
quantizer = QNNPackQuantizer()
operator_spec = qq.get_default_per_channel_symmetric_qnnpack_operator_spec()
quantizer.set_global(operator_spec)
operator_config = (
qq.get_default_per_channel_symmetric_qnnpack_quantization_config()
)
quantizer.set_global(operator_config)
m = prepare_pt2e_quantizer(m, quantizer)
# checking that we inserted observers correctly for maxpool operator (input and
# output share observer instance)
self.assertEqual(id(m.activation_post_process_3), id(m.activation_post_process_2))
self.assertEqual(
id(m.activation_post_process_3), id(m.activation_post_process_2)
)
after_prepare_result = m(*example_inputs)
m = convert_pt2e(m)
@ -567,7 +676,9 @@ class TestQuantizePT2EModels(QuantizationTestCase):
qconfig = default_per_channel_symmetric_qnnpack_qconfig
qconfig_mapping = QConfigMapping().set_global(qconfig)
backend_config = get_qnnpack_backend_config()
m_fx = prepare_fx(m_copy, qconfig_mapping, example_inputs, backend_config=backend_config)
m_fx = prepare_fx(
m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
)
after_prepare_result_fx = m_fx(*example_inputs)
m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
@ -579,8 +690,15 @@ class TestQuantizePT2EModels(QuantizationTestCase):
# but we can still manully inspect the printed observers to make sure
# it matches
self.assertEqual(after_prepare_result, after_prepare_result_fx)
self.assertEqual(compute_sqnr(after_prepare_result, after_prepare_result_fx), torch.tensor(float("inf")))
self.assertEqual(
compute_sqnr(after_prepare_result, after_prepare_result_fx),
torch.tensor(float("inf")),
)
# there are slight differences after convert due to different implementations
# of quant/dequant
self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1)
self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35)
self.assertTrue(
torch.max(after_quant_result - after_quant_result_fx) < 1e-1
)
self.assertTrue(
compute_sqnr(after_quant_result, after_quant_result_fx) > 35
)

View File

@ -1,7 +1,7 @@
from .quantizer import Quantizer
from .qnnpack_quantizer import QNNPackQuantizer
from .quantizer import OperatorConfig, Quantizer
__all__ = [
"Quantizer"
"Quantizer",
"QNNPackQuantizer",
]

View File

@ -1,95 +1,80 @@
from __future__ import annotations
from .quantizer import Quantizer
import copy
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Set, Dict, Callable
import operator
from typing import Callable, Dict, List, Optional, Set
import torch
import torch.nn.functional as F
from torch.ao.quantization.observer import (
PlaceholderObserver,
HistogramObserver,
MinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
import torch
import operator
from torch.fx import Node
from .quantizer import (
OperatorConfig,
OperatorPatternType,
QuantizationConfig,
QuantizationSpec,
Quantizer,
)
__all__ = [
"QNNPackQuantizer",
"get_default_symmetric_qnnpack_operator_spec",
"get_default_per_channel_symmetric_qnnpack_operator_spec",
"get_default_symmetric_qnnpack_quantization_config",
"get_default_per_channel_symmetric_qnnpack_quantization_config",
]
# TODO: maybe remove torch.float32
SUPPORTED_DTYPES = [torch.uint8, torch.int8, torch.int32, torch.float16, torch.float32]
SUPPORTED_QSCHEMES = [
torch.per_tensor_affine,
torch.per_tensor_symmetric,
torch.per_channel_affine,
torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams,
]
@dataclass(eq=True, frozen=True)
class TensorSpec:
dtype: torch.dtype
is_dynamic: bool = False
quant_min: Optional[int] = None
quant_max: Optional[int] = None
qscheme: Optional[torch.qscheme] = None
ch_axis: Optional[int] = None
def __post_init__(self):
# check dtype is one of the supported types
if self.dtype not in SUPPORTED_DTYPES:
raise TypeError(f"Unsupported dtype {self.dtype}.")
# quant_min must be less than quant_max
if self.quant_min is not None and self.quant_max is not None and self.quant_min > self.quant_max:
raise ValueError(
f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
)
# check qscheme is on of the supported ones
if self.qscheme is not None and self.qscheme not in SUPPORTED_QSCHEMES:
raise ValueError(f"Unsupported qscheme {self.qscheme}.")
# ch_axis must be less than the number of channels
# but no way to check here. Just check that it is not < 0.
if self.ch_axis is not None and self.ch_axis < 0:
raise ValueError("Ch_axis is < 0.")
OperatorSpec = NamedTuple(
"OperatorSpec", [("activation", TensorSpec), ("weight", TensorSpec), ("bias", TensorSpec)]
)
SpecAndOperators = NamedTuple(
"SpecAndOperators",
[("operator_spec", OperatorSpec), ("operators", List[str])]
)
def supported_symmetric_quantized_operators() -> List[str]:
supported_operators = ["conv2d", "linear", "add", "maxpool2d", "hardtanh", "mean", "adaptive_avgpool2d"]
def supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
supported_operators: Dict[str, List[OperatorPatternType]] = {
# Both conv and linear should be able to handle relu + hardtanh fusion since
# those are clamp ops
"conv2d": [
[torch.nn.Conv2d, torch.nn.ReLU],
[torch.nn.Conv2d, F.relu],
[F.conv2d, torch.nn.ReLU],
[F.conv2d, F.relu],
],
"linear": [[torch.nn.Linear], [F.linear]],
"add": [[torch.add]],
"maxpool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]],
"hardtanh": [[torch.nn.Hardtanh], [F.hardtanh]],
"mean": [[torch.mean]],
"adaptive_avgpool2d": [
[torch.nn.AdaptiveAvgPool2d],
[F.adaptive_avg_pool2d],
],
}
return copy.deepcopy(supported_operators)
def get_supported_symmetric_quantized_spec_and_operators() -> List[SpecAndOperators]:
supported_spec_and_operators: List[SpecAndOperators] = []
for operator_spec in [get_default_symmetric_qnnpack_operator_spec(), get_default_per_channel_symmetric_qnnpack_operator_spec()]:
ops = supported_symmetric_quantized_operators()
supported_spec_and_operators.append(SpecAndOperators(operator_spec, ops))
return copy.deepcopy(supported_spec_and_operators)
def get_default_symmetric_qnnpack_operator_spec():
act_tensor_spec = TensorSpec(
def get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
supported_config_and_operators: List[OperatorConfig] = []
for quantization_config in [
get_default_symmetric_qnnpack_quantization_config(),
get_default_per_channel_symmetric_qnnpack_quantization_config(),
]:
ops = supported_symmetric_quantized_operators()
for op_string, pattern_list in ops.items():
supported_config_and_operators.append(
OperatorConfig(quantization_config, pattern_list)
)
return copy.deepcopy(supported_config_and_operators)
def get_default_symmetric_qnnpack_quantization_config():
act_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
)
weight_tensor_spec = TensorSpec(
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
@ -97,19 +82,22 @@ def get_default_symmetric_qnnpack_operator_spec():
ch_axis=1,
is_dynamic=False,
)
bias_tensor_spec = TensorSpec(dtype=torch.float)
operator_spec = OperatorSpec(act_tensor_spec, weight_tensor_spec, bias_tensor_spec)
return operator_spec
bias_quantization_spec = QuantizationSpec(dtype=torch.float)
quantization_config = QuantizationConfig(
act_quantization_spec, weight_quantization_spec, bias_quantization_spec
)
return quantization_config
def get_default_per_channel_symmetric_qnnpack_operator_spec():
act_tensor_spec = TensorSpec(
def get_default_per_channel_symmetric_qnnpack_quantization_config():
act_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
)
weight_tensor_spec = TensorSpec(
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-127,
quant_max=127,
@ -117,27 +105,16 @@ def get_default_per_channel_symmetric_qnnpack_operator_spec():
ch_axis=1,
is_dynamic=False,
)
bias_tensor_spec = TensorSpec(dtype=torch.float)
operator_spec = OperatorSpec(act_tensor_spec, weight_tensor_spec, bias_tensor_spec)
return operator_spec
bias_quantization_spec = QuantizationSpec(dtype=torch.float)
quantization_config = QuantizationConfig(
act_quantization_spec, weight_quantization_spec, bias_quantization_spec
)
return quantization_config
def get_supported_spec_and_operators() -> List[SpecAndOperators]:
return get_supported_symmetric_quantized_spec_and_operators()
class OperatorSpecConfig:
def get_supported_config_and_operators() -> List[OperatorConfig]:
return get_supported_symmetric_config_and_operators()
def __init__(self):
super().__init__()
self.global_spec: Optional[OperatorSpec] = None
self.operator_type_specs: Dict[str, Optional[OperatorSpec]] = {}
def set_global(self, operator_spec: Optional[OperatorSpec]) -> OperatorSpecConfig:
self.global_spec = operator_spec
return self
def set_operator_type(self, operator_type: str, operator_spec: Optional[OperatorSpec]) -> OperatorSpecConfig:
self.operator_type_specs[operator_type] = operator_spec
return self
# TODO: add support for torch dtype in quant code base
# this includes observers and prepare/convert code
@ -147,61 +124,76 @@ _TORCH_DTYPE_TO_QDTYPE = {
torch.int32: torch.qint32,
torch.float16: torch.float16,
}
def _get_act_obs_or_fq_ctr(operator_spec: Optional[OperatorSpec]):
if operator_spec is None:
def _get_act_obs_or_fq_ctr(quantization_config: Optional[QuantizationConfig]):
if quantization_config is None:
return None
assert operator_spec is not None
tensor_spec: TensorSpec = operator_spec.activation
qdtype = _TORCH_DTYPE_TO_QDTYPE[tensor_spec.dtype]
assert tensor_spec.qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]
if not tensor_spec.is_dynamic:
assert quantization_config is not None
quantization_spec: QuantizationSpec = quantization_config.activation
qdtype = _TORCH_DTYPE_TO_QDTYPE[quantization_spec.dtype]
assert quantization_spec.qscheme in [
torch.per_tensor_affine,
torch.per_tensor_symmetric,
]
if not quantization_spec.is_dynamic:
return HistogramObserver.with_args(
dtype=qdtype,
quant_min=tensor_spec.quant_min,
quant_max=tensor_spec.quant_max,
quant_min=quantization_spec.quant_min,
quant_max=quantization_spec.quant_max,
reduce_range=False,
eps=2**-12
eps=2**-12,
)
else:
# TODO: extend this helper function to support dynamic quantization
raise Exception("Unsupported tensor_spec for activation: {}".format(tensor_spec))
def _get_weight_obs_or_fq_ctr(operator_spec: Optional[OperatorSpec]):
if operator_spec is None:
return None
assert operator_spec is not None
tensor_spec: TensorSpec = operator_spec.weight
qdtype = _TORCH_DTYPE_TO_QDTYPE[tensor_spec.dtype]
if tensor_spec.qscheme == torch.per_tensor_symmetric:
return MinMaxObserver.with_args(
qscheme=tensor_spec.qscheme,
dtype=qdtype,
quant_min=tensor_spec.quant_min,
quant_max=tensor_spec.quant_max,
eps=2**-12
raise Exception(
"Unsupported quantization_spec for activation: {}".format(quantization_spec)
)
elif tensor_spec.qscheme == torch.per_channel_symmetric:
return PerChannelMinMaxObserver.with_args(
qscheme=tensor_spec.qscheme,
def _get_weight_obs_or_fq_ctr(quantization_config: Optional[QuantizationConfig]):
if quantization_config is None:
return None
assert quantization_config is not None
quantization_spec: QuantizationSpec = quantization_config.weight
qdtype = _TORCH_DTYPE_TO_QDTYPE[quantization_spec.dtype]
if quantization_spec.qscheme == torch.per_tensor_symmetric:
return MinMaxObserver.with_args(
qscheme=quantization_spec.qscheme,
dtype=qdtype,
quant_min=tensor_spec.quant_min,
quant_max=tensor_spec.quant_max,
eps=2**-12
quant_min=quantization_spec.quant_min,
quant_max=quantization_spec.quant_max,
eps=2**-12,
)
elif quantization_spec.qscheme == torch.per_channel_symmetric:
return PerChannelMinMaxObserver.with_args(
qscheme=quantization_spec.qscheme,
dtype=qdtype,
quant_min=quantization_spec.quant_min,
quant_max=quantization_spec.quant_max,
eps=2**-12,
)
else:
raise Exception("Unsupported tensor_spec for weight: {}".format(tensor_spec))
raise Exception(
"Unsupported quantization_spec for weight: {}".format(quantization_spec)
)
def _get_bias_obs_or_fq_ctr(operator_spec: Optional[OperatorSpec]):
if operator_spec is None:
def _get_bias_obs_or_fq_ctr(quantization_config: Optional[QuantizationConfig]):
if quantization_config is None:
return None
assert operator_spec is not None
tensor_spec: TensorSpec = operator_spec.bias
assert tensor_spec.dtype == torch.float, "Only float dtype for bias is supported for bias right now"
return PlaceholderObserver.with_args(dtype=tensor_spec.dtype)
assert quantization_config is not None
quantization_spec: QuantizationSpec = quantization_config.bias
assert (
quantization_spec.dtype == torch.float
), "Only float dtype for bias is supported for bias right now"
return PlaceholderObserver.with_args(dtype=quantization_spec.dtype)
def _get_default_obs_or_fq_ctr():
return PlaceholderObserver.with_args(dtype=torch.float)
def _is_annotated(nodes: List[Node]):
"""
Given a list of nodes (that represents an operator pattern),
@ -210,96 +202,106 @@ def _is_annotated(nodes: List[Node]):
"""
annotated = False
for node in nodes:
annotated = annotated or ("target_dtype_info" in node.meta and node.meta["target_dtype_info"].get("_annotated", False))
annotated = annotated or (
"target_dtype_info" in node.meta
and node.meta["target_dtype_info"].get("_annotated", False)
)
return annotated
class QNNPackQuantizer(Quantizer):
supported_spec_and_operators = get_supported_spec_and_operators()
supported_config_and_operators = get_supported_config_and_operators()
def __init__(self):
super().__init__()
self.operator_spec_config = OperatorSpecConfig()
self.global_config: Optional[QuantizationConfig] = None
self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}
@classmethod
def get_supported_operator_specs(cls) -> List[OperatorSpec]:
op_specs: Set[OperatorSpec] = set({})
for spec, _ in cls.supported_spec_and_operators:
op_specs.add(spec)
return list(op_specs)
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
op_configs: Set[QuantizationConfig] = set({})
for spec, _ in cls.supported_config_and_operators:
op_configs.add(spec)
return list(op_configs)
@classmethod
def get_supported_operator_for_operator_spec(cls, operator_spec: Optional[OperatorSpec]) -> List[str]:
if operator_spec is None:
def get_supported_operator_for_quantization_config(
cls, quantization_config: Optional[QuantizationConfig]
) -> List[OperatorPatternType]:
if quantization_config is None:
all_ops = []
for _, ops in cls.supported_spec_and_operators:
for _, ops in cls.supported_config_and_operators:
all_ops.extend(ops)
return all_ops
for spec, ops in cls.supported_spec_and_operators:
for config, ops in cls.supported_config_and_operators:
# note: this assumes each entry in cls.supported_spec_and_operators
# corresponds to one spec, e.g. we don't have
# [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
# where the first and second entry have the same spec but did not
# merge the op list
if spec == operator_spec:
if config == quantization_config:
return ops
return []
def set_global(self, operator_spec: Optional[OperatorSpec]) -> QNNPackQuantizer:
self.operator_spec_config.set_global(operator_spec)
def set_global(
self, quantization_config: Optional[QuantizationConfig]
) -> QNNPackQuantizer:
self.global_config = quantization_config
return self
def set_spec_for_operator_type(
self, operator_type: str, operator_spec: Optional[OperatorSpec]
def set_config_for_operator_type(
self, operator_type: str, quantization_config: Optional[QuantizationConfig]
) -> QNNPackQuantizer:
self.operator_spec_config.set_operator_type(operator_type, operator_spec)
self.operator_type_config[operator_type] = quantization_config
return self
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
""" just handling global spec for now
"""
global_spec = self.operator_spec_config.global_spec
ops = self.get_supported_operator_for_operator_spec(global_spec)
"""just handling global spec for now"""
global_config = self.global_config
ops = self.get_supported_operator_for_quantization_config(global_config)
# annotate the nodes from last to first since the matching is in the reversed order
# and fusion operator patterns (conv - relu) can get matched before single operator pattern (conv)
# and we will mark the matched node with "_annoated" so fusion operator pattern
# can take precedence over single operator pattern in this way
for node in reversed(model.graph.nodes):
for op in ops:
if op == "conv2d":
self._annotate_conv2d_relu(node, global_spec)
self._annotate_conv2d(node, global_spec)
elif op == "linear":
self._annotate_linear(node, global_spec)
elif op == "maxpool2d":
self._annotate_maxpool2d(node, global_spec)
elif op == "add":
self._annotate_add_relu(node, global_spec)
self._annotate_add(node, global_spec)
elif op == "hardtanh":
self._annotate_hardtanh(node, global_spec)
elif op == "mean":
self._annotate_mean(node, global_spec)
elif op == "adaptive_avgpool2d":
self._annotate_adaptive_avg_pool2d(node, global_spec)
# one improvement is to register node annotators for each
# supported op type.
self._annotate_conv2d_relu(node, global_config)
self._annotate_conv2d(node, global_config)
self._annotate_linear(node, global_config)
self._annotate_maxpool2d(node, global_config)
self._annotate_add_relu(node, global_config)
self._annotate_add(node, global_config)
self._annotate_hardtanh(node, global_config)
self._annotate_mean(node, global_config)
self._annotate_adaptive_avg_pool2d(node, global_config)
return model
def _annotate_conv2d_relu(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
if node.op != "call_function" or node.target not in [torch.ops.aten.relu_.default, torch.ops.aten.relu.default]:
def _annotate_conv2d_relu(
self, node: Node, quantization_config: Optional[QuantizationConfig]
) -> None:
if node.op != "call_function" or node.target not in [
torch.ops.aten.relu_.default,
torch.ops.aten.relu.default,
]:
return
relu_node = node
conv_node = relu_node.args[0]
assert isinstance(conv_node, Node)
if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.convolution.default
):
return
if _is_annotated([relu_node, conv_node]):
return
conv_node.meta["target_dtype_info"] = {
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"weight_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(operator_spec),
"bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(operator_spec),
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"weight_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(quantization_config),
"bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(quantization_config),
"output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
# TODO: validation of weight_index must be set if weight_obs_or_fq_ctr is set
"weight_index": 1,
@ -309,22 +311,27 @@ class QNNPackQuantizer(Quantizer):
}
relu_node.meta["target_dtype_info"] = {
"input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"_annotated": True,
}
def _annotate_conv2d(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
def _annotate_conv2d(
self, node: Node, quantization_config: Optional[QuantizationConfig]
) -> None:
conv_node = node
if conv_node.op != "call_function" or conv_node.target != torch.ops.aten.convolution.default:
if (
conv_node.op != "call_function"
or conv_node.target != torch.ops.aten.convolution.default
):
return
# skip annotation if it is already annotated
if _is_annotated([conv_node]):
return
conv_node.meta["target_dtype_info"] = {
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"weight_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(operator_spec),
"bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(operator_spec),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"weight_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(quantization_config),
"bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(quantization_config),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
# TODO: validation of weight_index must be set if weight_obs_or_fq_ctr is set
"weight_index": 1,
# TODO: validation of bias_index must be set if bias_obs_or_fq_ctr is set
@ -332,13 +339,21 @@ class QNNPackQuantizer(Quantizer):
"_annotated": True,
}
def _annotate_linear(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
def _annotate_linear(
self, node: Node, quantization_config: Optional[QuantizationConfig]
) -> None:
addmm_node = node
if addmm_node.op != "call_function" or addmm_node.target != torch.ops.aten.addmm.default:
if (
addmm_node.op != "call_function"
or addmm_node.target != torch.ops.aten.addmm.default
):
return
view_node = addmm_node.args[1]
assert isinstance(view_node, Node)
if view_node.op != "call_function" or view_node.target != torch.ops.aten.view.default:
if (
view_node.op != "call_function"
or view_node.target != torch.ops.aten.view.default
):
return
t_node = addmm_node.args[2]
assert isinstance(t_node, Node)
@ -349,106 +364,154 @@ class QNNPackQuantizer(Quantizer):
# bias and output act
addmm_node.meta["target_dtype_info"] = {
"bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(operator_spec),
"bias_obs_or_fq_ctr": _get_bias_obs_or_fq_ctr(quantization_config),
"input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"bias_index": 0,
"_annotated": True,
}
# input act
view_node.meta["target_dtype_info"] = {
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
"_annotated": True,
}
# weight
t_node.meta["target_dtype_info"] = {
"input_act_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(operator_spec),
"input_act_obs_or_fq_ctr": _get_weight_obs_or_fq_ctr(quantization_config),
"output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
"_annotated": True,
}
def _annotate_maxpool2d(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
if node.op != "call_function" or node.target != operator.getitem or node.args[1] != 0:
def _annotate_maxpool2d(
self, node: Node, quantization_config: Optional[QuantizationConfig]
) -> None:
if (
node.op != "call_function"
or node.target != operator.getitem
or node.args[1] != 0
):
return
getitem_node = node
maxpool_node = getitem_node.args[0]
assert isinstance(maxpool_node, Node)
if maxpool_node.op != "call_function" or maxpool_node.target != torch.ops.aten.max_pool2d_with_indices.default:
if (
maxpool_node.op != "call_function"
or maxpool_node.target != torch.ops.aten.max_pool2d_with_indices.default
):
return
if _is_annotated([getitem_node, maxpool_node]):
return
maxpool_node.meta["target_dtype_info"] = {
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
"_annotated": True,
}
getitem_node.meta["target_dtype_info"] = {
"input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"input_output_share_observers": True,
"_annotated": True,
}
def _annotate_input_out_obs_sharing_op(self, op: Callable, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
def _annotate_input_out_obs_sharing_op(
self,
op: Callable,
node: Node,
quantization_config: Optional[QuantizationConfig],
) -> None:
io_obs_sharing_node = node
if io_obs_sharing_node.op != "call_function" or io_obs_sharing_node.target != op:
if (
io_obs_sharing_node.op != "call_function"
or io_obs_sharing_node.target != op
):
return
if _is_annotated([io_obs_sharing_node]):
return
io_obs_sharing_node.meta["target_dtype_info"] = {
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"input_output_share_observers": True,
"_annotated": True,
}
def _annotate_hardtanh(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
self._annotate_input_out_obs_sharing_op(torch.ops.aten.hardtanh.default, node, operator_spec)
def _annotate_hardtanh(
self, node: Node, quantization_config: Optional[QuantizationConfig]
) -> None:
self._annotate_input_out_obs_sharing_op(
torch.ops.aten.hardtanh.default, node, quantization_config
)
def _annotate_mean(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
self._annotate_input_out_obs_sharing_op(torch.ops.aten.mean.default, node, operator_spec)
self._annotate_input_out_obs_sharing_op(torch.ops.aten.mean.dim, node, operator_spec)
def _annotate_mean(
self, node: Node, quantization_config: Optional[QuantizationConfig]
) -> None:
self._annotate_input_out_obs_sharing_op(
torch.ops.aten.mean.default, node, quantization_config
)
self._annotate_input_out_obs_sharing_op(
torch.ops.aten.mean.dim, node, quantization_config
)
def _annotate_adaptive_avg_pool2d(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
self._annotate_input_out_obs_sharing_op(torch.ops.aten.adaptive_avg_pool2d.default, node, operator_spec)
def _annotate_adaptive_avg_pool2d(
self, node: Node, quantization_config: Optional[QuantizationConfig]
) -> None:
self._annotate_input_out_obs_sharing_op(
torch.ops.aten.adaptive_avg_pool2d.default, node, quantization_config
)
def _annotate_add_relu(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
if node.op != "call_function" or node.target not in [torch.ops.aten.relu_.default, torch.ops.aten.relu.default]:
def _annotate_add_relu(
self, node: Node, quantization_config: Optional[QuantizationConfig]
) -> None:
if node.op != "call_function" or node.target not in [
torch.ops.aten.relu_.default,
torch.ops.aten.relu.default,
]:
return
relu_node = node
add_node = relu_node.args[0]
assert isinstance(add_node, Node)
if add_node.op != "call_function" or add_node.target not in [torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor]:
if add_node.op != "call_function" or add_node.target not in [
torch.ops.aten.add.Tensor,
torch.ops.aten.add_.Tensor,
]:
return
if _is_annotated([relu_node, add_node]):
return
add_node.meta["target_dtype_info"] = {
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"output_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
"_annotated": True,
}
relu_node.meta["target_dtype_info"] = {
"input_act_obs_or_fq_ctr": _get_default_obs_or_fq_ctr(),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"_annotated": True,
}
def _annotate_add(self, node: Node, operator_spec: Optional[OperatorSpec]) -> None:
def _annotate_add(
self, node: Node, quantization_config: Optional[QuantizationConfig]
) -> None:
add_node = node
if add_node.op != "call_function" or add_node.target not in [torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor]:
if add_node.op != "call_function" or add_node.target not in [
torch.ops.aten.add.Tensor,
torch.ops.aten.add_.Tensor,
]:
return
if _is_annotated([add_node]):
return
add_node.meta["target_dtype_info"] = {
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(operator_spec),
"input_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"output_act_obs_or_fq_ctr": _get_act_obs_or_fq_ctr(quantization_config),
"_annotated": True,
}
def validate(self, model: torch.fx.GraphModule) -> None:
pass
@classmethod
def get_supported_operators(cls) -> List[OperatorConfig]:
return cls.supported_config_and_operators

View File

@ -1,11 +1,90 @@
from abc import ABC
from abc import abstractmethod
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, List, NamedTuple, Optional
import torch
__all__ = [
"Quantizer",
]
# TODO: maybe remove torch.float32
SUPPORTED_DTYPES = [torch.uint8, torch.int8, torch.int32, torch.float16, torch.float32]
SUPPORTED_QSCHEMES = [
torch.per_tensor_affine,
torch.per_tensor_symmetric,
torch.per_channel_affine,
torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams,
]
@dataclass(eq=True, frozen=True)
class QuantizationSpec:
dtype: torch.dtype
is_dynamic: bool = False
quant_min: Optional[int] = None
quant_max: Optional[int] = None
qscheme: Optional[torch.qscheme] = None
ch_axis: Optional[int] = None
def __post_init__(self):
# check dtype is one of the supported types
if self.dtype not in SUPPORTED_DTYPES:
raise TypeError(f"Unsupported dtype {self.dtype}.")
# quant_min must be less than quant_max
if (
self.quant_min is not None
and self.quant_max is not None
and self.quant_min > self.quant_max
):
raise ValueError(
f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
)
# check qscheme is on of the supported ones
if self.qscheme is not None and self.qscheme not in SUPPORTED_QSCHEMES:
raise ValueError(f"Unsupported qscheme {self.qscheme}.")
# ch_axis must be less than the number of channels
# but no way to check here. Just check that it is not < 0.
if self.ch_axis is not None and self.ch_axis < 0:
raise ValueError("Ch_axis is < 0.")
# In the absence of better name, just winging it with QuantizationConfig
QuantizationConfig = NamedTuple(
"QuantizationConfig",
[
("activation", QuantizationSpec),
("weight", QuantizationSpec),
("bias", QuantizationSpec),
],
)
OperatorPatternType = List[Callable]
OperatorConfig = NamedTuple(
"OperatorConfig",
# fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
# Basically we are mapping a quantization config to some list of patterns.
# a pattern is defined as a list of nn module, function or builtin function names
# e.g. [nn.Conv2d, torch.relu, torch.add]
# We have not resolved whether fusion can be considered internal details of the
# quantizer hence it does not need communication to user.
# Note this pattern is not really informative since it does not really
# tell us the graph structure resulting from the list of ops.
[
("config", QuantizationConfig),
(
"operators",
List[OperatorPatternType],
),
],
)
class Quantizer(ABC):
# annotate nodes in the graph with observer or fake quant constructors
@ -18,3 +97,10 @@ class Quantizer(ABC):
@abstractmethod
def validate(self, model: torch.fx.GraphModule) -> None:
pass
# annotate nodes in the graph with observer or fake quant constructors
# to convey the desired way of quantization
@classmethod
@abstractmethod
def get_supported_operators(cls) -> List[OperatorConfig]:
pass