import torch
from torch.fx import GraphModule
from torch.fx.graph import (
    Node,
    Graph,
)
from ..observer import (
    default_affine_fixed_qparams_observer,
    default_symmetric_fixed_qparams_observer,
)

from ..quantization_mappings import (
    get_static_quant_module_class,
    get_dynamic_quant_module_class,
)
from ..utils import (
    _parent_name,
    get_swapped_custom_module_class,
    activation_is_statically_quantized,
    activation_is_int8_quantized,
    weight_is_statically_quantized,
    get_qconfig_dtypes,
    activation_dtype,
    get_qparam_dict,
)

from torch.ao.quantization.quantize import (
    is_activation_post_process,
)

from .pattern_utils import (
    register_quant_pattern,
    get_default_output_activation_post_process_map,
    Pattern,
)
from .utils import (
    all_node_args_have_no_tensors,
    quantize_node,
    get_per_tensor_qparams,
    get_linear_prepack_op_for_dtype,
    create_qparam_nodes,
    get_qconv_prepack_op,
    get_qconv_op,
    create_node_from_old_node_preserve_meta,
)

from ..qconfig import QConfigAny

from abc import ABC
import operator
import warnings

from typing import Any, Callable, Dict, Union, Optional, Tuple, List

# -------------------------
# Pattern Registrations
# -------------------------

# 1. Post Training Static Quantization and Quantization Aware Training Patterns

# Base Pattern Handler
class QuantizeHandler(ABC):
    """ Base handler class for the quantizer patterns
    """
    def __init__(self, node: Node, modules: Dict[str, torch.nn.Module]):
        """ Records pattern information in __init__, which will be used
        in convert
        """
        # this is an indicator of whether all the inputs are Node or not
        # since some op might be quantized differently depending on whether
        # all inputs are tensors or not, e.g. add/mul
        self.num_tensor_args = len(node.args)
        self.all_node_args_are_tensors = True
        # the last node of the matched pattern
        self.last_node = node

    def _maybe_get_last_node_only_observer(
        self,
        modules: Dict[str, torch.nn.Module]
    ) -> Optional[torch.nn.Module]:
        """
        If the last node of the pattern is observed, return the observer
        instance. Otherwise, return None.
        """
        for maybe_obs_node, _ in self.last_node.users.items():
            if maybe_obs_node.op == 'call_module':
                maybe_obs = modules[str(maybe_obs_node.target)]
                if is_activation_post_process(maybe_obs):
                    return maybe_obs
        return None

    def input_output_observed(self) -> bool:
        """
        Returns True if the pattern matched to this qhandler could be
        observed, and False if it should not be observed.
        """
        return True

    def is_general_tensor_value_op(self) -> bool:
        """
        Returns True if the operator works for both floating point and
        quantized input, and does some computation based on the input Tensor.
        We need to insert an observer/fake_quant for the output of the
        operator, since the distribution of values is different for input and
        output Tensors (for HistogramObserver) while they share the same
        quantization parameters.
        Example: avgpool2d
        """
        return False

    def is_general_tensor_shape_op(self) -> bool:
        """ Similar to is_general_tensor_value_op, this is a check
        for ops that work for both floating point and quantized input,
        and only re-arrange the Tensor values or query some metadata about the Tensor.
        We don't insert observer/fake_quant for the output of these operators.
        Example: reshape, transpose, maxpool2d
        """
        return False

    def should_insert_observer_for_output(
        self,
        qconfig: Any,
        model_is_training: bool,
    ) -> bool:
        """
        Returns true if an observer should be inserted for the output of
        the pattern matched to this QuantizeHandler instance during the
        prepare step.
        """
        # TODO(future PR): potentially clean up and deduplicate these
        # mappings.
        return self.all_node_args_are_tensors and self.input_output_observed()

    def should_mark_output_quantized_from_input_quantized_status(
        self,
        qconfig: QConfigAny
    ) -> bool:
        """
        Returns true if after convert, the output of the matched pattern is
        quantized iff the first input is also quantized.
        """
        return False

    def get_activation_ctr(
        self,
        qconfig: Any,
        pattern: Pattern,
        is_training: bool,
    ) -> Optional[Callable]:
        """
        Returns the constructor for the activation observer which should be
        used for the pattern matched to this handler. Some handlers override
        this to a different value than what is specified in the qconfig.
        """
        return qconfig.activation

    def is_output_quantized(self, qconfig):
        """ Returns true if the output node of convert is quantized.
        When is_reference is False, we return a float node when a certain dtype
        combination is not supported (since fbgemm/qnnpack only support certain dtype
        combinations), so the output may be float; when is_reference is True,
        we support all dtype combinations so the output will always be quantized.

        TODO: This is fragile, whether the output is quantized should not depend on
        `is_reference`, since whether a Tensor is quantized should be the same in
        prepare and convert, and is_reference is only available in convert currently

        """
        return True

    def convert(self,
                node: Node,
                qconfig: QConfigAny,
                modules: Dict[str, torch.nn.Module],
                quantized_graph: Graph,
                node_name_to_scope: Dict[str, Tuple[str, type]],
                load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """ Convert the given node to a quantized node and insert
        it to the quantized graph
        """
        return NotImplemented

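
# Illustrative sketch (not part of the original file, and not registered with
# @register_quant_pattern): a minimal QuantizeHandler subclass that leaves the
# matched op unquantized by copying it into the quantized graph. The concrete
# handlers below follow the same shape: record pattern information in __init__,
# then emit the (reference-)quantized nodes in convert().
class _ExamplePassthroughQuantizeHandler(QuantizeHandler):
    def convert(self,
                node: Node,
                qconfig: QConfigAny,
                modules: Dict[str, torch.nn.Module],
                quantized_graph: Graph,
                node_name_to_scope: Dict[str, Tuple[str, type]],
                load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        # copy the node unchanged, loading all args as float
        return quantized_graph.node_copy(node, load_arg(quantized=torch.float))
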
# Binary op configs

# Supported combinations are:
# quant_type | activation (compute_type) | weight
#  static       quint8                      qint8

# tuple (activation_dtype, weight_dtype, compute_dtype)
# these are supported types for common binary ops like add/mul etc.
all_dtypes = [
    (torch.qint8, torch.qint8, None),
    (torch.quint8, torch.qint8, None),
    (torch.float16, torch.float16, None),
]
fp16_dtypes = [
    (torch.float16, torch.float16, None)
]
int8_dtypes = [
    (torch.qint8, torch.qint8, None),
    (torch.quint8, torch.qint8, None),
]
binary_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {
    operator.add: all_dtypes,
    torch.add: all_dtypes,
    operator.mul: all_dtypes,
    torch.mul: all_dtypes,
    torch.bmm: fp16_dtypes,
    torch.sub: fp16_dtypes,
    operator.sub: fp16_dtypes,
    torch.div: fp16_dtypes,
    operator.truediv: fp16_dtypes,
    torch.matmul: int8_dtypes,
}

default_op_supported_dtypes = {
    torch.nn.ConvTranspose1d: int8_dtypes,
    torch.nn.ConvTranspose2d: int8_dtypes,
    torch.nn.ELU: int8_dtypes,
    torch.nn.LeakyReLU: int8_dtypes,
    torch.nn.Hardswish: int8_dtypes,
    torch.nn.InstanceNorm1d: int8_dtypes,
    torch.nn.InstanceNorm2d: int8_dtypes,
    torch.nn.InstanceNorm3d: int8_dtypes,
    torch.nn.LayerNorm: all_dtypes,
    torch.nn.SiLU: fp16_dtypes,
    torch.nn.Mish: fp16_dtypes,
    torch.nn.GELU: int8_dtypes,
    torch.nn.Dropout: int8_dtypes,
    torch.nn.Softmax: int8_dtypes,
    torch.nn.functional.elu: int8_dtypes,
    torch.nn.functional.hardswish: int8_dtypes,
    torch.nn.functional.instance_norm: int8_dtypes,
    torch.nn.functional.layer_norm: all_dtypes,
    torch.nn.functional.leaky_relu: int8_dtypes,
    torch.nn.functional.silu: fp16_dtypes,
    torch.nn.functional.mish: fp16_dtypes,
    torch.nn.functional.gelu: int8_dtypes,
    torch.nn.functional.softmax: int8_dtypes,
    torch.nn.functional.dropout: int8_dtypes,
    torch.sum: fp16_dtypes,
}

QAT_CONV_MODULE_CLASSES = \
    (torch.nn.qat.Conv2d,
     torch.nn.qat.Conv3d,
     torch.nn.intrinsic.qat.ConvBn1d,
     torch.nn.intrinsic.qat.ConvBn2d,
     torch.nn.intrinsic.qat.ConvBn3d,
     torch.nn.intrinsic.qat.ConvBnReLU1d,
     torch.nn.intrinsic.qat.ConvBnReLU2d,
     torch.nn.intrinsic.qat.ConvBnReLU3d,
     torch.nn.intrinsic.qat.ConvReLU2d,
     torch.nn.intrinsic.qat.ConvReLU3d)

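
# Illustrative sketch (not part of the original file): how the dtype tables above
# are consulted. get_qconfig_dtypes(qconfig) returns an
# (activation_dtype, weight_dtype, compute_dtype) tuple, and a binary op is only
# quantized when that tuple appears in binary_op_supported_dtypes for the op.
def _example_binary_op_dtypes_supported(binary_op: Union[Callable, str], qconfig: QConfigAny) -> bool:
    dtypes = get_qconfig_dtypes(qconfig)
    return binary_op in binary_op_supported_dtypes and \
        dtypes in binary_op_supported_dtypes[binary_op]
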
##########################
# Helper Functions
##########################

def _load_weight_qparams(
        self, state_dict, prefix, local_metadata, strict,
        missing_keys, unexpected_keys, error_msgs):
    key = prefix + "_weight_qparams"
    if key in state_dict:
        self._weight_qparams = state_dict[key]
        state_dict.pop(key)

def _save_weight_qparams(self, destination, prefix, keep_vars):
    for attr_name in dir(self):
        if "_weight_qparams" == attr_name and \
           isinstance(getattr(self, attr_name), dict):
            weight_qparams = getattr(self, attr_name)
            destination[prefix + attr_name] = weight_qparams


def _to_reference(float_module, weight_qparams):
    """ Make a weighted float module (e.g. conv and linear) a reference module by
    attaching _weight_qparams that records the qparams for weight
    and changing the name of the module so that it's recognized
    when people print the model
    """
    float_module._weight_qparams = weight_qparams
    float_module._register_state_dict_hook(_save_weight_qparams)
    float_module._register_load_state_dict_pre_hook(_load_weight_qparams, with_module=True)

    float_module_name = float_module._get_name()

    def _get_name():
        return float_module_name + "(Reference)"

    float_module._get_name = _get_name

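
# Illustrative sketch (not part of the original file): using _to_reference to turn
# a float nn.Linear into a "reference" module. The weight_qparams dict shown here
# is a hypothetical per-tensor example; in convert it is produced by
# get_qparam_dict(weight_post_process) from the observed weight.
def _example_to_reference_usage() -> torch.nn.Linear:
    float_linear = torch.nn.Linear(4, 4)
    example_weight_qparams = {
        "qscheme": torch.per_tensor_affine,
        "scale": 0.1,
        "zero_point": 0,
    }
    _to_reference(float_linear, example_weight_qparams)
    # float_linear's printed name is now "Linear(Reference)" and it saves
    # "<prefix>_weight_qparams" into its state_dict
    return float_linear
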
@register_quant_pattern(operator.add)
@register_quant_pattern(operator.sub)
@register_quant_pattern(operator.mul)
@register_quant_pattern(operator.truediv)
@register_quant_pattern(torch.add)
@register_quant_pattern(torch.sub)
@register_quant_pattern(torch.mul)
@register_quant_pattern(torch.div)
@register_quant_pattern(torch.bmm)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.ReLU, torch.add))
@register_quant_pattern((torch.nn.ReLU, torch.mul))
@register_quant_pattern((torch.nn.functional.relu, operator.add))
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
@register_quant_pattern((torch.nn.functional.relu, torch.add))
@register_quant_pattern((torch.nn.functional.relu, torch.mul))
@register_quant_pattern((torch.relu, operator.add))
@register_quant_pattern((torch.relu, operator.mul))
@register_quant_pattern(torch.matmul)
class BinaryOpQuantizeHandler(QuantizeHandler):
    def __init__(
            self,
            node: Node,
            modules: Dict[str, torch.nn.Module]):
        super().__init__(node, modules)
        self.relu_node = None
        if (
            node.op == 'call_function' and
            node.target in (torch.nn.functional.relu, torch.relu)
        ) or (
            node.op == 'call_module' and
            isinstance(modules[str(node.target)], torch.nn.ReLU)
        ):
            self.relu_node = node
            node = node.args[0]  # type: ignore[assignment]
        self.binary_op_node = node
        self.binary_op = node.target

        # determine how many of the first two args are Tensors (versus scalars)
        # this distinguishes things like "x + y" from "x + 2" or "2 + x"
        self.num_tensor_args = 0
        cache_for_no_tensor_check: Dict[Node, bool] = dict()
        for arg_idx in range(len(self.binary_op_node.args)):
            arg = self.binary_op_node.args[arg_idx]
            if isinstance(arg, Node) and (not all_node_args_have_no_tensors(arg, modules, cache_for_no_tensor_check)):
                self.num_tensor_args += 1
        self.all_node_args_are_tensors = \
            (self.num_tensor_args == len(self.binary_op_node.args))

    def should_insert_observer_for_output(
        self,
        qconfig: Any,
        model_is_training: bool,
    ) -> bool:
        """
        Returns true if an observer should be inserted for the output of
        the pattern matched to this QuantizeHandler instance during the
        prepare step.
        """
        dtypes = get_qconfig_dtypes(qconfig)
        if not (self.binary_op in binary_op_supported_dtypes and dtypes in binary_op_supported_dtypes[self.binary_op]):
            return False
        if self.num_tensor_args == 1:
            return True
        elif self.all_node_args_are_tensors and self.input_output_observed():
            return True
        else:
            return False

    def is_general_tensor_value_op(self) -> bool:
        return self.num_tensor_args == 1

    def input_output_observed(self):
        # for x + y where x and y are scalars, we do not observe anything
        return self.num_tensor_args > 0

    def is_output_quantized(self, qconfig):
        dtypes = get_qconfig_dtypes(qconfig)
        return self.binary_op in binary_op_supported_dtypes and \
            dtypes in binary_op_supported_dtypes[self.binary_op]

    def convert(self,
                node: Node,
                qconfig: QConfigAny,
                modules: Dict[str, torch.nn.Module],
                quantized_graph: Graph,
                node_name_to_scope: Dict[str, Tuple[str, type]],
                load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:

        if self.num_tensor_args == 0:
            # example: x + y, when x and y are scalars
            return quantized_graph.node_copy(
                node, load_arg(quantized=None))

        act_dtype = activation_dtype(qconfig)
        dtypes = get_qconfig_dtypes(qconfig)
        if act_dtype == torch.float or \
                not (self.binary_op in binary_op_supported_dtypes and dtypes in binary_op_supported_dtypes[self.binary_op]):
            if self.relu_node:
                op_out = quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=torch.float))
                relu_args = [op_out]
                relu_args.extend(load_arg(quantized=torch.float)(self.relu_node.args[1:]))
                relu_kwargs = load_arg(quantized=torch.float)(self.relu_node.kwargs)
                return create_node_from_old_node_preserve_meta(
                    quantized_graph,
                    ("call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs),
                    self.relu_node)
            else:
                return quantized_graph.node_copy(node, load_arg(quantized=torch.float))
        else:
            if self.num_tensor_args == 2:
                # make sure both inputs are quantized to act_dtype
                load_arg(quantized={0: act_dtype, 1: act_dtype})(self.binary_op_node.args)
            args = load_arg(quantized=torch.float)(self.binary_op_node.args)
            kwargs = load_arg(quantized=torch.float)(self.binary_op_node.kwargs)
            op_out = quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=torch.float))

            def modified_load_arg(n: Node):
                if n.name == self.binary_op_node.name:
                    return op_out
                else:
                    return load_arg(quantized=torch.float)(n)

            if self.relu_node:
                op_out = quantized_graph.node_copy(self.relu_node, modified_load_arg)
            activation_post_process = \
                self._maybe_get_last_node_only_observer(modules)
            assert activation_post_process is not None
            return quantize_node(
                op_out, activation_post_process,
                node, modules, quantized_graph, node_name_to_scope, is_input=False)

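# Illustrative sketch (not part of the original file): why BinaryOpQuantizeHandler
# counts tensor args. For "x + y" (two tensors) both inputs and the output are
# quantized; for "x + 2" (one tensor, one scalar) the op behaves like a general
# tensor value op and shares qparams with its tensor input; for "2 + 3" (no tensors)
# nothing is observed or quantized. A rough standalone version of the same count:
def _example_count_tensor_args(args: Tuple[Any, ...]) -> int:
    num_tensor_args = 0
    for arg in args:
        # the handler above uses all_node_args_have_no_tensors on fx Nodes;
        # here we simply count torch.Tensor instances for illustration
        if isinstance(arg, torch.Tensor):
            num_tensor_args += 1
    return num_tensor_args
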
@register_quant_pattern(torch.cat)
class CatQuantizeHandler(QuantizeHandler):
    def is_general_tensor_value_op(self) -> bool:
        return True

    def convert(self,
                node: Node,
                qconfig: QConfigAny,
                modules: Dict[str, torch.nn.Module],
                quantized_graph: Graph,
                node_name_to_scope: Dict[str, Tuple[str, type]],
                load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        if not self.all_node_args_are_tensors:
            return NotImplemented
        act_dtype = activation_dtype(qconfig)
        if act_dtype == torch.float:
            op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
            return op_out
        else:
            activation_post_process = \
                self._maybe_get_last_node_only_observer(modules)
            assert activation_post_process is not None
            # make sure the first argument is quantized to act_dtype
            load_arg(quantized={0: act_dtype})(node.args)
            args = list(load_arg(quantized=torch.float)(node.args))
            kwargs = load_arg(quantized=torch.float)(node.kwargs)
            op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
            return quantize_node(
                op_out,
                activation_post_process,
                node,
                modules,
                quantized_graph,
                node_name_to_scope,
                is_input=False)

# handle conv, maybe followed by relu
|
|
# NB: matching order is reversed, that is we match from the bottom of this list to the beginning
|
|
@register_quant_pattern(torch.nn.Conv1d)
|
|
@register_quant_pattern(torch.nn.Conv2d)
|
|
@register_quant_pattern(torch.nn.Conv3d)
|
|
@register_quant_pattern(torch.nn.functional.conv1d)
|
|
@register_quant_pattern(torch.nn.functional.conv2d)
|
|
@register_quant_pattern(torch.nn.functional.conv3d)
|
|
# TODO: add qat.Conv1d
|
|
@register_quant_pattern(torch.nn.qat.Conv2d)
|
|
@register_quant_pattern(torch.nn.qat.Conv3d)
|
|
@register_quant_pattern(torch.nn.intrinsic.ConvReLU1d)
|
|
@register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
|
|
@register_quant_pattern(torch.nn.intrinsic.ConvReLU3d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn1d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn3d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU1d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU3d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU3d)
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv1d))
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv3d))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv1d))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv3d))
|
|
# just for error checks
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv1d))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv3d))
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
|
|
# TODO: rename Relu -> ReLU to be more consistent with other classes
|
|
class ConvReluQuantizeHandler(QuantizeHandler):
|
|
def __init__(self, node: Node, modules: Dict[str, torch.nn.Module]):
|
|
super().__init__(node, modules)
|
|
self.relu_node = None
|
|
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
|
|
(node.op == 'call_module' and isinstance(modules[str(node.target)], torch.nn.ReLU)):
|
|
self.relu_node = node
|
|
node = node.args[0] # type: ignore[assignment]
|
|
self.conv_node = node
|
|
if node.op == "call_module":
|
|
self.conv = modules[str(self.conv_node.target)]
|
|
elif node.op == "call_function":
|
|
self.conv = node.target # type: ignore[assignment]
|
|
|
|
def convert(self,
|
|
node: Node,
|
|
qconfig: QConfigAny,
|
|
modules: Dict[str, torch.nn.Module],
|
|
quantized_graph: Graph,
|
|
node_name_to_scope: Dict[str, Tuple[str, type]],
|
|
load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
# Supported combinations are:
|
|
# quant_type | activation (compute_type) | weight
|
|
# static quint8 qint8
|
|
|
|
# tuple (activation_dtype, weight_dtype, compute_dtype)
|
|
supported_dtypes = [
|
|
(torch.quint8, torch.qint8, None),
|
|
]
|
|
|
|
# TODO: is_reference option for conv module
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
# leave the op unquantized if the dtype combination is not supported
|
|
if not is_reference and dtypes not in supported_dtypes:
|
|
warnings.warn(
|
|
"dtype combination: {} is not "
|
|
"supported by Conv "
|
|
"supported dtype combinations are: {}".format(dtypes, supported_dtypes))
|
|
if self.relu_node:
|
|
conv_out = quantized_graph.node_copy(self.conv_node, load_arg(quantized=torch.float))
|
|
relu_args = [conv_out]
|
|
relu_args.extend(load_arg(quantized=torch.float)(self.relu_node.args[1:]))
|
|
relu_kwargs = load_arg(quantized=torch.float)(self.relu_node.kwargs)
|
|
return create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
("call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs),
|
|
self.relu_node)
|
|
else:
|
|
return quantized_graph.node_copy(node, load_arg(quantized=torch.float))
|
|
|
|
activation_int8_quantized = activation_is_int8_quantized(qconfig)
|
|
|
|
if self.conv_node.op == 'call_module':
|
|
# note that relu should already be fused into conv module in the fusion step
|
|
assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
|
|
'please make sure to run fusion before prepare'
|
|
output_activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(modules)
|
|
assert output_activation_post_process is not None
|
|
|
|
module_types_supports_reference_pattern = [
|
|
torch.nn.Conv1d,
|
|
torch.nn.Conv2d,
|
|
torch.nn.Conv3d,
|
|
torch.nn.intrinsic.ConvReLU1d,
|
|
torch.nn.intrinsic.ConvReLU2d,
|
|
torch.nn.intrinsic.ConvReLU3d,
|
|
]
|
|
module_types_supports_reference_pattern.extend(list(QAT_CONV_MODULE_CLASSES))
|
|
# We'll always produce reference pattern for torch.nn.Conv*d,
|
|
# will remove the else branch after we migrated all use cases
|
|
if is_reference or \
|
|
type(self.conv) in module_types_supports_reference_pattern and \
|
|
dtypes in [(torch.quint8, torch.qint8, None)]:
|
|
# produce dequant - float_op - quant pattern
|
|
dtype = torch.float
|
|
if activation_int8_quantized:
|
|
dtype = activation_dtype(qconfig)
|
|
activation = load_arg(quantized=dtype)(self.conv_node.args[0])
|
|
args = load_arg(quantized=torch.float)(self.conv_node.args)
|
|
# Get the float conv and attach quantization scheme and quantization
|
|
# parameters of weight to the module
|
|
# and qparam is a dictionary of
|
|
# {"qscheme": ..., "scale": ..., "zero_point": ...} for per tensor quantization or
|
|
# {"qscheme": ..., "scale": ..., "zero_point": ..., "axis": ...} for per channel quantization
|
|
float_conv = self.conv
|
|
fused_conv = None
|
|
if isinstance(
|
|
float_conv,
|
|
QAT_CONV_MODULE_CLASSES):
|
|
                    # case 1. converting qat conv module to
                    # a float conv module, we need to attach
                    # weight fake_quant to the conv module,
                    # weight fake_quant is assumed to be run during
                    # QAT so we don't need to run it again here
                    float_conv = float_conv.to_float()  # type: ignore[operator]
|
|
# change qat conv to conv
|
|
parent_name, name = _parent_name(self.conv_node.target)
|
|
setattr(modules[parent_name], name, float_conv)
|
|
if isinstance(float_conv, torch.nn.intrinsic._FusedModule):
|
|
fused_conv = float_conv
|
|
float_conv = fused_conv[0]
|
|
weight_post_process = self.conv.weight_fake_quant
|
|
else:
|
|
# case 2. converting a conv module/fused conv module
|
|
# to float conv module, we need to attach
|
|
# weight observer to the conv module and run it
|
|
# with conv weight
|
|
if isinstance(float_conv, torch.nn.intrinsic._FusedModule):
|
|
fused_conv = float_conv
|
|
float_conv = fused_conv[0] # type: ignore[index]
|
|
assert qconfig is not None
|
|
weight_post_process = qconfig.weight()
|
|
|
|
# return early when we don't have a valid match
|
|
# this typically happens when we called the same conv multiple times in the
|
|
# same graph, and it is transformed in previous steps into a reference conv already
|
|
if type(float_conv) not in [torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d]:
|
|
op_out = create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
('call_module', self.conv_node.target, args, {}),
|
|
self.conv_node)
|
|
return op_out
|
|
|
|
qconv_cls = get_static_quant_module_class(
|
|
type(float_conv), is_reference=True)
|
|
# run weight observer
|
|
# TODO: This is currently a hack for QAT to get the right shapes for scale and zero point.
|
|
# In the future, we should require the user to calibrate the model after calling prepare
|
|
weight_post_process(float_conv.weight) # type: ignore[operator]
|
|
weight_qparams = get_qparam_dict(weight_post_process)
|
|
# hardcoded for now, TODO: expose the api to user,
|
|
# we can have a map from module to reference module
|
|
# and allow user to register new ones
|
|
ref_conv = qconv_cls.from_float(float_conv, weight_qparams) # type: ignore[attr-defined]
|
|
# if the parent is a fused conv (Sequential), we can replace the first
|
|
# item to ref conv, otherwise we can update
|
|
# the conv instance in the module tree
|
|
if fused_conv is not None:
|
|
fused_conv[0] = ref_conv
|
|
parent_name, name = _parent_name(self.conv_node.target)
|
|
setattr(modules[parent_name], name, fused_conv)
|
|
else:
|
|
parent_name, name = _parent_name(self.conv_node.target)
|
|
setattr(modules[parent_name], name, ref_conv)
|
|
op_out = create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
('call_module', self.conv_node.target, args, {}),
|
|
self.conv_node)
|
|
if output_activation_post_process:
|
|
op_out = quantize_node(
|
|
op_out,
|
|
output_activation_post_process,
|
|
node,
|
|
modules,
|
|
quantized_graph,
|
|
node_name_to_scope,
|
|
is_input=False)
|
|
return op_out
|
|
else:
|
|
if convert_custom_config_dict is None:
|
|
convert_custom_config_dict = {}
|
|
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
|
|
# 1. attach activation post process to module
|
|
self.conv.activation_post_process = output_activation_post_process
|
|
# 2. select quantized class
|
|
qconv_cls = get_static_quant_module_class(
|
|
type(self.conv), additional_static_quant_mapping, is_reference=is_reference)
|
|
quantized = qconv_cls.from_float(self.conv)
|
|
parent_name, name = _parent_name(self.conv_node.target)
|
|
setattr(modules[parent_name], name, quantized)
|
|
return create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
(
|
|
'call_module',
|
|
self.conv_node.target,
|
|
(load_arg(quantized=torch.quint8)(self.conv_node.args[0]),),
|
|
{},
|
|
),
|
|
self.conv_node)
|
|
else: # call_function
|
|
assert self.conv_node.op == "call_function"
|
|
conv_functional_ops = {
|
|
torch.nn.functional.conv1d,
|
|
torch.nn.functional.conv2d,
|
|
torch.nn.functional.conv3d,
|
|
}
|
|
if is_reference or self.conv_node.target in conv_functional_ops and\
|
|
dtypes in [(torch.quint8, torch.qint8, None)]:
|
|
# make sure the input and weight are quantized to torch.quint8, torch.qint8, respectively
|
|
load_arg(quantized={0: torch.quint8, 1: torch.qint8})(self.conv_node.args)
|
|
args = load_arg(quantized=torch.float)(self.conv_node.args)
|
|
kwargs = load_arg(quantized=torch.float)(self.conv_node.kwargs)
|
|
op_out = create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
("call_function", self.conv, args, kwargs),
|
|
self.conv_node)
|
|
if self.relu_node:
|
|
relu_args = [op_out]
|
|
relu_args.extend(load_arg(quantized=torch.float)(self.relu_node.args[1:]))
|
|
relu_kwargs = load_arg(quantized=torch.float)(self.relu_node.kwargs)
|
|
op_out = create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
("call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs),
|
|
self.relu_node)
|
|
|
|
if activation_int8_quantized:
|
|
root_module = modules['']
|
|
act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
|
|
act_post_process_node = self.relu_node if self.relu_node else self.conv_node
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(modules)
|
|
assert activation_post_process is not None
|
|
return quantize_node(
|
|
op_out,
|
|
activation_post_process,
|
|
act_post_process_node,
|
|
modules,
|
|
quantized_graph,
|
|
node_name_to_scope,
|
|
is_input=False)
|
|
else:
|
|
# output for dynamically quantized conv op is not quantized
|
|
return op_out
|
|
else:
|
|
assert len(self.conv_node.args) >= 7, \
|
|
"only conv2d calls with all arguments specified is supported right now in is_reference=False option"
|
|
# make sure the input and weight are quantized to torch.quint8, torch.qint8, respectively
|
|
args = load_arg(quantized={0: torch.quint8, 1: torch.qint8})(self.conv_node.args)
|
|
# pack weight
|
|
weight = load_arg(quantized=torch.qint8)(self.conv_node.args[1])
|
|
other_args = load_arg(quantized=torch.float)(self.conv_node.args[2:])
|
|
bias, stride, padding, dilation, groups = other_args
|
|
if self.conv == torch.nn.functional.conv1d:
|
|
# F.conv1d can take `int` as well as `list[int]` for stride,
|
|
# padding, dilation, but the prepack op cannot. Convert
|
|
# these to lists if needed.
|
|
stride = [stride] if isinstance(stride, int) else stride
|
|
padding = [padding] if isinstance(padding, int) else padding
|
|
dilation = [dilation] if isinstance(dilation, int) else dilation
|
|
prepack_args = (weight, bias, stride, padding, dilation, groups)
|
|
prepack_op = get_qconv_prepack_op(self.conv)
|
|
packed_weight = quantized_graph.create_node(
|
|
"call_function", prepack_op, prepack_args, {})
|
|
assert activation_int8_quantized, \
|
|
"currently only static quantization is supported for conv"
|
|
# construct conv input
|
|
if activation_int8_quantized:
|
|
qconv_op = get_qconv_op(self.conv, self.relu_node is not None)
|
|
conv_input = load_arg(quantized=torch.quint8)(self.conv_node.args[0])
|
|
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(modules)
|
|
assert activation_post_process is not None
|
|
|
|
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
|
|
scale_node, zero_point_node = \
|
|
create_qparam_nodes(
|
|
self.conv_node.name, scale, zero_point, modules,
|
|
quantized_graph, node_name_to_scope)
|
|
qconv_args = (conv_input, packed_weight, scale_node, zero_point_node)
|
|
kwargs = load_arg(quantized=torch.float)(self.conv_node.kwargs)
|
|
op = create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
('call_function', qconv_op, qconv_args, kwargs),
|
|
self.conv_node)
|
|
# Store the name of the fused op to get the path of node after fusion as well.
|
|
# TODO: may need to change the key to Node regenerate the map in each transformation,
|
|
# since we might not be able to rely on the name
|
|
node_name_to_scope[op.name] = node_name_to_scope[self.conv_node.name]
|
|
return op
|
|
                else:
                    # conv2d_dynamic branch
                    raise Exception("Only static quant is supported for conv")
|
|
|
|
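# Illustrative sketch (not part of the original file): the "dequant - float_op - quant"
# reference pattern that the conv/linear handlers produce when is_reference=True,
# written out in eager mode. The scale/zero_point values are hypothetical placeholders;
# in convert they come from the output observer via get_per_tensor_qparams.
def _example_reference_conv_pattern(xq: torch.Tensor, float_conv: torch.nn.Conv2d) -> torch.Tensor:
    x = xq.dequantize()              # dequantize the int8 input
    y = float_conv(x)                # run the op in float (reference module keeps a float weight)
    yq = torch.quantize_per_tensor(  # re-quantize the output with the observed qparams
        y, scale=0.1, zero_point=0, dtype=torch.quint8)
    return yq
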
@register_quant_pattern(torch.nn.Linear)
|
|
@register_quant_pattern(torch.nn.functional.linear)
|
|
@register_quant_pattern(torch.nn.qat.Linear)
|
|
@register_quant_pattern(torch.nn.intrinsic.LinearReLU)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.LinearReLU)
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.linear))
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.linear))
|
|
@register_quant_pattern(torch.nn.intrinsic.LinearBn1d)
|
|
@register_quant_pattern(torch.nn.intrinsic.qat.LinearBn1d)
|
|
# for error checks
|
|
@register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
|
|
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
|
|
class LinearReLUQuantizeHandler(QuantizeHandler):
|
|
def __init__(
|
|
self,
|
|
node: Node,
|
|
modules: Dict[str, torch.nn.Module]):
|
|
super().__init__(node, modules)
|
|
self.relu_node = None
|
|
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
|
|
(node.op == 'call_module' and isinstance(modules[str(node.target)], torch.nn.ReLU)):
|
|
self.relu_node = node
|
|
node = node.args[0] # type: ignore[assignment]
|
|
self.linear_node = node
|
|
if node.op == 'call_module':
|
|
self.linear = modules[str(self.linear_node.target)]
|
|
|
|
def convert(self,
|
|
node: Node,
|
|
qconfig: QConfigAny,
|
|
modules: Dict[str, torch.nn.Module],
|
|
quantized_graph: Graph,
|
|
node_name_to_scope: Dict[str, Tuple[str, type]],
|
|
load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
if convert_custom_config_dict is None:
|
|
convert_custom_config_dict = {}
|
|
# Supported combinations are:
|
|
# quant_type | activation (compute_type) | weight
|
|
# static quint8 qint8
|
|
# dynamic float32 (quint8) qint8
|
|
# weight_only float32 float16
|
|
# tuple (activation_dtype, weight_dtype, compute_dtype)
|
|
supported_dtypes = [
|
|
(torch.quint8, torch.qint8, None),
|
|
(torch.float32, torch.qint8, torch.quint8),
|
|
(torch.float32, torch.float16, None),
|
|
# static float16 quantization
|
|
(torch.float16, torch.float16, None),
|
|
]
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
# leave the op unquantized if the dtype combination is not supported
|
|
if not is_reference and dtypes not in supported_dtypes:
|
|
warnings.warn(
|
|
"dtype combination: {} is not "
|
|
"supported by Linear "
|
|
"supported dtype combinations are: {}".format(dtypes, supported_dtypes))
|
|
if self.relu_node:
|
|
op_out = quantized_graph.node_copy(self.linear_node, load_arg(quantized=torch.float))
|
|
relu_args = [op_out]
|
|
relu_args.extend(load_arg(quantized=torch.float)(self.relu_node.args[1:]))
|
|
relu_kwargs = load_arg(quantized=torch.float)(self.relu_node.kwargs)
|
|
return create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
("call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs),
|
|
self.relu_node)
|
|
else:
|
|
return quantized_graph.node_copy(node, load_arg(quantized=None))
|
|
|
|
activation_int8_quantized = activation_is_int8_quantized(qconfig)
|
|
activation_statically_quantized = activation_is_statically_quantized(qconfig)
|
|
weight_dtype = dtypes[1]
|
|
if self.linear_node.op == 'call_module':
|
|
|
|
output_activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(modules)
|
|
|
|
            # note that relu should already be fused into linear module in the fusion step
            assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \
                'please make sure to run fusion before prepare'
|
|
# we'll always produce reference pattern for the following modules
|
|
# will remove the else branch after we migrated all use cases
|
|
module_allowlist = [
|
|
torch.nn.Linear,
|
|
torch.nn.qat.Linear,
|
|
torch.nn.intrinsic.LinearReLU,
|
|
torch.nn.intrinsic.qat.LinearReLU,
|
|
torch.nn.intrinsic.LinearBn1d,
|
|
torch.nn.intrinsic.qat.LinearBn1d,
|
|
]
|
|
if is_reference or type(self.linear) in module_allowlist and dtypes in [(torch.quint8, torch.qint8, None)]:
|
|
# produce dequant - float_op - quant pattern
|
|
dtype = torch.float
|
|
if activation_int8_quantized:
|
|
dtype = activation_dtype(qconfig)
|
|
activation = load_arg(quantized=dtype)(self.linear_node.args[0])
|
|
args = load_arg(quantized=torch.float)(self.linear_node.args)
|
|
|
|
                # Get the float linear and attach qscheme and qparams to the module
|
|
float_linear = self.linear
|
|
fused_linear = None
|
|
qat_modules = (
|
|
torch.nn.qat.Linear,
|
|
torch.nn.intrinsic.qat.LinearReLU,
|
|
torch.nn.intrinsic.qat.LinearBn1d,
|
|
)
|
|
static_fused_modules = (
|
|
torch.nn.intrinsic.LinearReLU,
|
|
torch.nn.intrinsic.LinearBn1d,
|
|
)
|
|
if isinstance(float_linear, qat_modules):
|
|
float_linear = float_linear.to_float()
|
|
# change qat linear to linear
|
|
parent_name, name = _parent_name(self.linear_node.target)
|
|
setattr(modules[parent_name], name, float_linear)
|
|
# Attach weight fake quant to the linear module
|
|
if isinstance(float_linear, static_fused_modules):
|
|
fused_linear = float_linear
|
|
float_linear = float_linear[0]
|
|
weight_post_process = self.linear.weight_fake_quant
|
|
else:
|
|
if isinstance(float_linear, static_fused_modules):
|
|
fused_linear = float_linear
|
|
float_linear = self.linear[0] # type: ignore[index]
|
|
# Attach the weight observer to the module
|
|
weight_post_process = qconfig.weight() # type: ignore[union-attr]
|
|
|
|
# Run weight observer
|
|
# TODO: This is currently a hack for QAT to get the right shapes for scale and zero point.
|
|
# In the future, we should require the user to calibrate the model after calling prepare
|
|
weight_post_process(float_linear.weight) # type: ignore[operator]
|
|
|
|
weight_qparams = get_qparam_dict(weight_post_process)
|
|
# TODO: include the configuration in backend_config_dict
|
|
# we can have a map from module to reference module
|
|
# and allow user to register new ones
|
|
qlinear_cls = get_static_quant_module_class(
|
|
type(float_linear), is_reference=True)
|
|
ref_linear = qlinear_cls.from_float(float_linear, weight_qparams)
|
|
|
|
# if the parent is a fused linear (Sequential), we can replace the first
|
|
# item to ref linear, otherwise we can update
|
|
# the linear instance in the module tree
|
|
if fused_linear is not None:
|
|
fused_linear[0] = ref_linear
|
|
else:
|
|
parent_name, name = _parent_name(self.linear_node.target)
|
|
setattr(modules[parent_name], name, ref_linear)
|
|
op_out = create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
('call_module', self.linear_node.target, args, {}),
|
|
self.linear_node)
|
|
if output_activation_post_process:
|
|
op_out = quantize_node(
|
|
op_out,
|
|
output_activation_post_process,
|
|
node,
|
|
modules,
|
|
quantized_graph,
|
|
node_name_to_scope,
|
|
is_input=False)
|
|
return op_out
|
|
# non-reference option
|
|
else:
|
|
# 1. attach output activation post process to linear module
|
|
if output_activation_post_process:
|
|
self.linear.activation_post_process = output_activation_post_process
|
|
|
|
# 2. select corresponding quantized linear class for the float linear class
|
|
if activation_int8_quantized:
|
|
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
|
|
qlinear = get_static_quant_module_class(
|
|
type(self.linear), additional_static_quant_mapping)
|
|
else:
|
|
assert dtypes in [
|
|
(torch.float32, torch.qint8, torch.quint8),
|
|
(torch.float32, torch.float16, None),
|
|
], f"dtype {dtypes} not supported yet"
|
|
additional_dynamic_quant_mapping = convert_custom_config_dict.get("dynamic", {})
|
|
qlinear = get_dynamic_quant_module_class(type(self.linear), additional_dynamic_quant_mapping)
|
|
|
|
quantized = qlinear.from_float(self.linear)
|
|
parent_name, name = _parent_name(self.linear_node.target)
|
|
setattr(modules[parent_name], name, quantized)
|
|
# activation needs to be quantized for static quantization
|
|
dtype = torch.float
|
|
if activation_int8_quantized:
|
|
dtype = activation_dtype(qconfig)
|
|
return create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
(
|
|
'call_module',
|
|
self.linear_node.target,
|
|
(load_arg(quantized=dtype)(self.linear_node.args[0]),), {},
|
|
),
|
|
self.linear_node)
|
|
else: # call_function
|
|
assert self.linear_node.op == 'call_function'
|
|
if is_reference or self.linear_node.target == torch.nn.functional.linear and\
|
|
dtypes in [(torch.quint8, torch.qint8, None)]:
|
|
quantized_input_dtypes = [torch.float, torch.float]
|
|
if activation_int8_quantized:
|
|
quantized_input_dtypes[0] = torch.quint8
|
|
if weight_is_statically_quantized(qconfig):
|
|
quantized_input_dtypes[1] = torch.qint8
|
|
args = load_arg(quantized=quantized_input_dtypes)(self.linear_node.args)
|
|
args = load_arg(quantized=torch.float)(self.linear_node.args)
|
|
kwargs = load_arg(quantized=torch.float)(self.linear_node.kwargs)
|
|
op_out = create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
("call_function", torch.nn.functional.linear, args, kwargs),
|
|
self.linear_node)
|
|
if self.relu_node:
|
|
relu_args = [op_out]
|
|
relu_args.extend(load_arg(quantized=torch.float)(self.relu_node.args[1:]))
|
|
relu_kwargs = load_arg(quantized=torch.float)(self.relu_node.kwargs)
|
|
op_out = create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
("call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs),
|
|
self.relu_node)
|
|
|
|
if activation_statically_quantized:
|
|
# quantize output for statically quantized linear op
|
|
root_module = modules['']
|
|
act_post_process_name = self.relu_node.name if self.relu_node else self.linear_node.name
|
|
act_post_process_node = self.relu_node if self.relu_node else self.linear_node
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(modules)
|
|
assert activation_post_process is not None
|
|
return quantize_node(
|
|
op_out,
|
|
activation_post_process,
|
|
act_post_process_node,
|
|
modules,
|
|
quantized_graph,
|
|
node_name_to_scope,
|
|
is_input=False,
|
|
output_prefix="")
|
|
else:
|
|
# output for dynamically quantized linear op is not quantized
|
|
return op_out
|
|
else: # non-reference option
|
|
# prepacking weights for static int8 quant and dynamic quant
|
|
if dtypes != (torch.float16, torch.float16, None):
|
|
# linear args
|
|
# (x, weight, bias, ...)
|
|
# TODO: the name should be weight is int8 quantized
|
|
weight_quantized = weight_is_statically_quantized(qconfig)
|
|
dtype = weight_dtype if weight_quantized else torch.float
|
|
linear_weight = load_arg(quantized=dtype)(self.linear_node.args[1])
|
|
|
|
# get other arguments
|
|
kwargs = {**load_arg(quantized=torch.float)(self.linear_node.kwargs)}
|
|
# all args after bias, including bias
|
|
other_args = load_arg(quantized=torch.float)(self.linear_node.args[2:])
|
|
# bias might be either positional, or a keyword argument
|
|
if len(self.linear_node.args) > 2:
|
|
bias = load_arg(quantized=torch.float)(self.linear_node.args[2])
|
|
other_args = other_args[1:] # remove the bias argument
|
|
else:
|
|
bias = kwargs.pop('bias', None)
|
|
|
|
prepack_args = (linear_weight, bias)
|
|
prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
|
|
packed_weight = quantized_graph.create_node(
|
|
'call_function', prepack_op, prepack_args, {})
|
|
# construct linear input
|
|
if activation_int8_quantized:
|
|
qlinear_op = torch.ops.quantized.linear_relu if self.relu_node else torch.ops.quantized.linear
|
|
linear_input = load_arg(quantized=torch.quint8)(self.linear_node.args[0])
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(modules)
|
|
assert activation_post_process is not None
|
|
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
|
|
scale_node, zero_point_node = \
|
|
create_qparam_nodes(
|
|
self.linear_node.name, scale, zero_point, modules,
|
|
quantized_graph, node_name_to_scope)
|
|
|
|
qlinear_args = (linear_input, packed_weight, scale_node, zero_point_node)
|
|
op = create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
("call_function", qlinear_op, qlinear_args, kwargs),
|
|
self.linear_node)
|
|
# Store the name of the fused op to get the path of node after fusion as well.
|
|
# TODO: may need to change the key to Node regenerate the map in each transformation,
|
|
# since we might not be able to rely on the name
|
|
node_name_to_scope[op.name] = node_name_to_scope[self.linear_node.name]
|
|
return op
|
|
elif dtypes in [(torch.float32, torch.qint8, torch.quint8),
|
|
(torch.float32, torch.float16, None)]:
|
|
# choose linear dynamic or linear dynamic fp16 op based on weight dtype
|
|
if weight_dtype == torch.qint8:
|
|
if self.relu_node:
|
|
qlinear_op = torch.ops.quantized.linear_relu_dynamic
|
|
else:
|
|
qlinear_op = torch.ops.quantized.linear_dynamic
|
|
else:
|
|
if self.relu_node:
|
|
qlinear_op = torch.ops.quantized.linear_relu_dynamic_fp16
|
|
else:
|
|
qlinear_op = torch.ops.quantized.linear_dynamic_fp16
|
|
|
|
linear_input = load_arg(quantized=torch.float)(self.linear_node.args[0])
|
|
qlinear_args = (linear_input, packed_weight) # type: ignore[assignment]
|
|
op_out = create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
("call_function", qlinear_op, qlinear_args, kwargs),
|
|
self.linear_node)
|
|
# Store the name of the dynamic op to get the path of node after replacement as well.
|
|
# TODO: may need to change the key to Node regenerate the map in each transformation,
|
|
# since we might not be able to rely on the name
|
|
node_name_to_scope[op_out.name] = node_name_to_scope[self.linear_node.name]
|
|
return op_out
|
|
else:
|
|
assert dtypes == (torch.float16, torch.float16, None)
|
|
# TODO (refactor) this is duplicated, maybe have a helper function
|
|
if self.relu_node:
|
|
op_out = quantized_graph.node_copy(self.linear_node, load_arg(quantized=torch.float))
|
|
relu_args = [op_out]
|
|
relu_args.extend(load_arg(quantized=torch.float)(self.relu_node.args[1:]))
|
|
relu_kwargs = load_arg(quantized=torch.float)(self.relu_node.kwargs)
|
|
op_out = create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
("call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs),
|
|
self.relu_node)
|
|
else:
|
|
op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
|
|
return quantized_graph.create_node(
|
|
"call_method", "to", (op_out, torch.float16), {})
|
|
|
|
@register_quant_pattern(torch.nn.BatchNorm2d)
|
|
@register_quant_pattern(torch.nn.BatchNorm3d)
|
|
@register_quant_pattern(torch.nn.intrinsic.BNReLU2d)
|
|
@register_quant_pattern(torch.nn.intrinsic.BNReLU3d)
|
|
class BatchNormQuantizeHandler(QuantizeHandler):
|
|
def __init__(
|
|
self,
|
|
node: Node,
|
|
modules: Dict[str, torch.nn.Module]):
|
|
super().__init__(node, modules)
|
|
assert node.op == 'call_module'
|
|
self.bn_node = node
|
|
self.bn = modules[str(self.bn_node.target)]
|
|
|
|
def convert(self,
|
|
node: Node,
|
|
qconfig: QConfigAny,
|
|
modules: Dict[str, torch.nn.Module],
|
|
quantized_graph: Graph,
|
|
node_name_to_scope: Dict[str, Tuple[str, type]],
|
|
load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
if convert_custom_config_dict is None:
|
|
convert_custom_config_dict = {}
|
|
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
|
|
# 1. attach activation post process to module
|
|
output_activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(modules)
|
|
assert output_activation_post_process is not None
|
|
if is_reference:
|
|
# produce dequant - float_op - quant pattern
|
|
dtype = activation_dtype(qconfig)
|
|
activation = load_arg(quantized=dtype)(self.bn_node.args[0])
|
|
args = load_arg(quantized=torch.float)(self.bn_node.args)
|
|
op_out = create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
("call_module", self.bn_node.target, args, {}),
|
|
self.bn_node)
|
|
if output_activation_post_process:
|
|
op_out = quantize_node(
|
|
op_out,
|
|
output_activation_post_process,
|
|
node,
|
|
modules,
|
|
quantized_graph,
|
|
node_name_to_scope,
|
|
is_input=False)
|
|
return op_out
|
|
else:
|
|
self.bn.activation_post_process = output_activation_post_process
|
|
qbn_cls = get_static_quant_module_class(type(self.bn), additional_static_quant_mapping)
|
|
quantized = qbn_cls.from_float(self.bn)
|
|
parent_name, name = _parent_name(self.bn_node.target)
|
|
setattr(modules[parent_name], name, quantized)
|
|
return create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
(
|
|
'call_module',
|
|
self.bn_node.target,
|
|
load_arg(quantized=[0])(self.bn_node.args),
|
|
load_arg(quantized=torch.float)(self.bn_node.kwargs),
|
|
),
|
|
self.bn_node)
|
|
|
|
@register_quant_pattern(torch.nn.qat.Embedding)
|
|
@register_quant_pattern(torch.nn.qat.EmbeddingBag)
|
|
@register_quant_pattern(torch.nn.Embedding)
|
|
@register_quant_pattern(torch.nn.EmbeddingBag)
|
|
class EmbeddingQuantizeHandler(QuantizeHandler):
|
|
def __init__(
|
|
self,
|
|
node: Node,
|
|
modules: Dict[str, torch.nn.Module]):
|
|
super().__init__(node, modules)
|
|
|
|
def input_output_observed(self) -> bool:
|
|
return False
|
|
|
|
def convert(self,
|
|
node: Node,
|
|
qconfig: QConfigAny,
|
|
modules: Dict[str, torch.nn.Module],
|
|
quantized_graph: Graph,
|
|
node_name_to_scope: Dict[str, Tuple[str, type]],
|
|
load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
# Supported combinations are:
|
|
# quant_type | activation | weight | activation_compute_type
|
|
# weight_only | float32 | quint8 | None
|
|
# weight_only | float32 | quint4x2 | None
|
|
# tuple (activation_dtype, weight_dtype, compute_dtype)
|
|
supported_dtypes = [
|
|
(torch.float32, torch.quint8, None),
|
|
(torch.float32, torch.quint4x2, None),
|
|
]
|
|
assert node.op == 'call_module'
|
|
emb_node = node
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
# leave the op unquantized if the dtype combination is not supported
|
|
if dtypes not in supported_dtypes:
|
|
warnings.warn(
|
|
"dtype combination: {} is not "
|
|
"supported by Embedding/EmbeddingBag, "
|
|
"supported dtype combinations are: {}".format(dtypes, supported_dtypes))
|
|
return quantized_graph.node_copy(node, load_arg(quantized=None))
|
|
|
|
emb = modules[str(emb_node.target)]
|
|
qemb = get_static_quant_module_class(type(emb))
|
|
quantized = qemb.from_float(emb)
|
|
parent_name, name = _parent_name(emb_node.target)
|
|
setattr(modules[parent_name], name, quantized)
|
|
return create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
(
|
|
'call_module',
|
|
emb_node.target,
|
|
load_arg(quantized=torch.float)(emb_node.args),
|
|
load_arg(quantized=torch.float)(emb_node.kwargs),
|
|
),
|
|
emb_node)
|
|
|
|
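# Illustrative sketch (not part of the original file): Embedding/EmbeddingBag are
# weight-only quantized, so the activation stays float32 and only the weight dtype
# (quint8 or quint4x2) matters, matching the supported_dtypes table in
# EmbeddingQuantizeHandler.convert above.
def _example_embedding_weight_only_dtypes(qconfig: QConfigAny) -> bool:
    # weight-only quantization: activation stays float32, weight is quint8 or quint4x2
    return get_qconfig_dtypes(qconfig) in [
        (torch.float32, torch.quint8, None),
        (torch.float32, torch.quint4x2, None),
    ]
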
# TODO (maybe): merge with embedding quantize handler
|
|
@register_quant_pattern(torch.nn.GRUCell)
|
|
@register_quant_pattern(torch.nn.LSTMCell)
|
|
@register_quant_pattern(torch.nn.RNNCell)
|
|
@register_quant_pattern(torch.nn.LSTM)
|
|
class RNNDynamicQuantizeHandler(QuantizeHandler):
|
|
def __init__(
|
|
self,
|
|
node: Node,
|
|
modules: Dict[str, torch.nn.Module]):
|
|
super().__init__(node, modules)
|
|
|
|
def convert(self,
|
|
node: Node,
|
|
qconfig: QConfigAny,
|
|
modules: Dict[str, torch.nn.Module],
|
|
quantized_graph: Graph,
|
|
node_name_to_scope: Dict[str, Tuple[str, type]],
|
|
load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
# Supported combinations are:
|
|
# quant_type | activation | weight | activation_compute_type
|
|
# dynamic | float32 | qint8 | quint8
|
|
# dynamic | float32 | float16 | None
|
|
# tuple (activation_dtype, weight_dtype, compute_dtype)
|
|
supported_dtypes = [
|
|
(torch.float32, torch.qint8, torch.quint8),
|
|
(torch.float32, torch.float16, None),
|
|
]
|
|
assert node.op == 'call_module'
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
# leave the op unquantized if the dtype combination is not supported
|
|
if dtypes not in supported_dtypes:
|
|
warnings.warn(
|
|
"dtype combination: {} is not "
|
|
"supported by Embedding/EmbeddingBag, "
|
|
"supported dtype combinations are: {}".format(dtypes, supported_dtypes))
|
|
return quantized_graph.node_copy(node, load_arg(quantized=None))
|
|
|
|
act_dtype, weight_dtype, compute_dtype = dtypes
|
|
activation = load_arg(quantized=act_dtype)(node.args[0])
|
|
module = modules[str(node.target)]
|
|
qmodule_cls = get_dynamic_quant_module_class(type(module))
|
|
qmodule = qmodule_cls.from_float(module)
|
|
parent_name, name = _parent_name(node.target)
|
|
setattr(modules[parent_name], name, qmodule)
|
|
return create_node_from_old_node_preserve_meta(
|
|
quantized_graph,
|
|
(
|
|
'call_module',
|
|
node.target,
|
|
load_arg(quantized=torch.float)(node.args),
|
|
load_arg(quantized=torch.float)(node.kwargs),
|
|
),
|
|
node)
|
|
|
|
ARGS_TO_SKIP = {
|
|
torch._ops.ops.quantized.hardswish: ['inplace'],
|
|
torch._ops.ops.quantized.elu: ['inplace'],
|
|
torch._ops.ops.quantized.dropout: ['inplace'],
|
|
torch._ops.ops.quantized.instance_norm:
|
|
['running_mean', 'running_var', 'use_input_stats', 'momentum'],
|
|
}
|
|
@register_quant_pattern(torch.nn.ConvTranspose1d)
|
|
@register_quant_pattern(torch.nn.ConvTranspose2d)
|
|
@register_quant_pattern(torch.nn.ELU)
|
|
@register_quant_pattern(torch.nn.LeakyReLU)
|
|
@register_quant_pattern(torch.nn.Hardswish)
|
|
@register_quant_pattern(torch.nn.InstanceNorm1d)
|
|
@register_quant_pattern(torch.nn.InstanceNorm2d)
|
|
@register_quant_pattern(torch.nn.InstanceNorm3d)
|
|
@register_quant_pattern(torch.nn.LayerNorm)
|
|
@register_quant_pattern(torch.nn.SiLU)
|
|
@register_quant_pattern(torch.nn.Mish)
|
|
@register_quant_pattern(torch.nn.Dropout)
|
|
# we currently only support reference patterns for these ops so they have been removed
|
|
# until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig
|
|
# @register_quant_pattern(torch.nn.GELU)
|
|
# @register_quant_pattern(torch.nn.Softmax)
|
|
@register_quant_pattern(torch.nn.functional.elu)
|
|
@register_quant_pattern(torch.nn.functional.hardswish)
|
|
@register_quant_pattern(torch.nn.functional.instance_norm)
|
|
@register_quant_pattern(torch.nn.functional.layer_norm)
|
|
@register_quant_pattern(torch.nn.functional.leaky_relu)
|
|
@register_quant_pattern(torch.nn.functional.silu)
|
|
@register_quant_pattern(torch.nn.functional.mish)
|
|
@register_quant_pattern(torch.nn.functional.dropout)
|
|
# we currently only support reference patterns for these ops so they have been removed
|
|
# until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig
|
|
# @register_quant_pattern(torch.nn.functional.gelu)
|
|
# @register_quant_pattern(torch.nn.functional.softmax)
|
|
@register_quant_pattern(torch.sum)
|
|
class DefaultNodeQuantizeHandler(QuantizeHandler):
|
|
""" Common quantized op, first input and first output will be quantized
|
|
"""
|
|
def __init__(
|
|
self,
|
|
node: Node,
|
|
modules: Dict[str, torch.nn.Module]):
|
|
super().__init__(node, modules)
|
|
if node.op == "call_function" or node.op == "call_method":
|
|
self.op = node.target
|
|
elif node.op == "call_module":
|
|
self.op = type(modules[str(node.target)])
|
|
|
|
def is_output_quantized(self, qconfig):
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
return self.op in default_op_supported_dtypes and \
|
|
dtypes in default_op_supported_dtypes[self.op]
|
|
|
|
def convert(self,
|
|
node: Node,
|
|
qconfig: QConfigAny,
|
|
modules: Dict[str, torch.nn.Module],
|
|
quantized_graph: Graph,
|
|
node_name_to_scope: Dict[str, Tuple[str, type]],
|
|
load_arg: Callable,
|
|
is_reference: bool = False,
|
|
convert_custom_config_dict: Dict[str, Any] = None) -> Node:
|
|
if not self.all_node_args_are_tensors:
|
|
return NotImplemented
|
|
assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
|
|
'call_function are handled in DefaultNode'
|
|
if convert_custom_config_dict is None:
|
|
convert_custom_config_dict = {}
|
|
additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
|
|
|
|
dtypes = get_qconfig_dtypes(qconfig)
|
|
if not is_reference and dtypes not in default_op_supported_dtypes[self.op]:
|
|
warnings.warn(
|
|
"dtype combination: {} is not "
|
|
"supported by {} "
|
|
"supported dtype combinations are: {}".format(dtypes, self.op, default_op_supported_dtypes[self.op]))
|
|
return quantized_graph.node_copy(node, load_arg(quantized=torch.float))
|
|
|
|
        # We can produce reference patterns for dtypes including
        # (torch.quint8, torch.qint8, torch.qint32, torch.float16)
|
|
act_dtype = activation_dtype(qconfig)
|
|
if act_dtype == torch.float:
|
|
op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
|
|
return op_out
|
|
else:
|
|
activation_post_process = \
|
|
self._maybe_get_last_node_only_observer(modules)
|
|
assert activation_post_process is not None
|
|
# make sure the input is quantized to act_dtype
|
|
load_arg(quantized={0: act_dtype})(node.args)
|
|
args = load_arg(quantized=torch.float)(node.args)
|
|
kwargs = load_arg(quantized=torch.float)(node.kwargs)
|
|
# swap float module to reference module (ConvTranspose)
|
|
float_module = modules[str(node.target)] if node.op == "call_module" else None
|
|
if type(float_module) in [torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d]:
|
|
ref_module_cls = get_static_quant_module_class(type(float_module), is_reference=True)
|
|
|
|
weight_post_process = qconfig.weight() # type: ignore[union-attr]
|
|
weight_post_process(float_module.weight) # type: ignore[union-attr]
|
|
weight_qparams = get_qparam_dict(weight_post_process)
|
|
ref_module = ref_module_cls.from_float(float_module, weight_qparams) # type: ignore[attr-defined]
|
|
parent_name, name = _parent_name(node.target)
|
|
setattr(modules[parent_name], name, ref_module)
|
|
op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
|
|
return quantize_node(
|
|
op_out, activation_post_process,
|
|
node, modules, quantized_graph, node_name_to_scope, is_input=False)
|
|
|
|
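
# Illustrative sketch of how an op covered by the patterns above flows through the
# standard FX workflow. This is a hedged example only (the module `M`, the calibration
# input, and the dict-style qconfig_dict are assumptions about typical usage, not part
# of this file); it is kept as a comment so nothing runs at import time.
#
#     import torch
#     from torch.ao.quantization import get_default_qconfig
#     from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
#
#     class M(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.norm = torch.nn.LayerNorm(4)
#
#         def forward(self, x):
#             return self.norm(x)
#
#     m = prepare_fx(M().eval(), {"": get_default_qconfig("fbgemm")})
#     m(torch.randn(2, 4))   # calibration populates the observers
#     q = convert_fx(m)      # LayerNorm is matched by DefaultNodeQuantizeHandler, so its
#                            # first input and output should end up quantized and the module
#                            # lowered to its quantized counterpart
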
@register_quant_pattern(torch.nn.Hardsigmoid, default_affine_fixed_qparams_observer)
@register_quant_pattern(torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_observer)
@register_quant_pattern('hardsigmoid', default_affine_fixed_qparams_observer)
@register_quant_pattern('hardsigmoid_', default_affine_fixed_qparams_observer)
@register_quant_pattern(torch.nn.Sigmoid, default_affine_fixed_qparams_observer)
@register_quant_pattern(torch.sigmoid, default_affine_fixed_qparams_observer)
@register_quant_pattern('sigmoid', default_affine_fixed_qparams_observer)
@register_quant_pattern('sigmoid_', default_affine_fixed_qparams_observer)
@register_quant_pattern(torch.nn.Tanh, default_symmetric_fixed_qparams_observer)
@register_quant_pattern(torch.tanh, default_symmetric_fixed_qparams_observer)
@register_quant_pattern('tanh', default_symmetric_fixed_qparams_observer)
@register_quant_pattern('tanh_', default_symmetric_fixed_qparams_observer)
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
    def __init__(self,
                 node: Node,
                 modules: Dict[str, torch.nn.Module]):
        super().__init__(node, modules)
        self.node = node

    def should_mark_output_quantized_from_input_quantized_status(
            self,
            qconfig: QConfigAny
    ) -> bool:
        # FixedQParams ops are treated the same as CopyNode ops in int8 quantization
        return activation_dtype(qconfig) in [torch.quint8, torch.qint8]

    # some qhandlers override the activations constructor
    def get_activation_ctr(self, qconfig, pattern, is_training) -> Optional[Callable]:
        act_dtype = activation_dtype(qconfig)
        if act_dtype == torch.quint8:
            return get_default_output_activation_post_process_map(is_training).get(
                pattern, qconfig.activation)
        else:
            return qconfig.activation

    def convert(self,
                node: Node,
                qconfig: QConfigAny,
                modules: Dict[str, torch.nn.Module],
                quantized_graph: Graph,
                node_name_to_scope: Dict[str, Tuple[str, type]],
                load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        act_dtype = activation_dtype(qconfig)
        if act_dtype == torch.float:
            op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
            return op_out
        else:
            activation_post_process = \
                self._maybe_get_last_node_only_observer(modules)
            assert activation_post_process is not None
            # make sure the input is quantized to act_dtype
            load_arg(quantized={0: act_dtype})(node.args)
            args = load_arg(quantized=torch.float)(node.args)
            kwargs = load_arg(quantized=torch.float)(node.kwargs)
            op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
            return quantize_node(
                op_out, activation_post_process,
                node, modules, quantized_graph, node_name_to_scope, is_input=False)
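
# The ops registered above have a known output range (e.g. sigmoid in [0, 1], tanh in
# [-1, 1]), so their outputs use fixed-qparams observers instead of learned statistics.
# A hedged standalone sketch of what that means for the output quantization (the values
# shown are the commonly used fixed qparams and may differ by backend or version):
#
#     import torch
#
#     x = torch.randn(4)
#     qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
#     y = torch.sigmoid(qx.dequantize())
#     # after convert, the sigmoid output is re-quantized with fixed qparams, roughly
#     # scale=1/256, zero_point=0 for sigmoid/hardsigmoid (affine, range [0, 1]) and
#     # scale=2/256, zero_point=128 for tanh (symmetric, range [-1, 1])
#     qy = torch.quantize_per_tensor(y, scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8)
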
@register_quant_pattern(torch.nn.AdaptiveAvgPool1d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool3d)
@register_quant_pattern(torch.nn.AvgPool1d)
@register_quant_pattern(torch.nn.AvgPool2d)
@register_quant_pattern(torch.nn.AvgPool3d)
@register_quant_pattern(torch.nn.Hardtanh)
@register_quant_pattern(torch.nn.MaxPool1d)
@register_quant_pattern(torch.nn.MaxPool2d)
@register_quant_pattern(torch.nn.MaxPool3d)
@register_quant_pattern(torch.nn.ReLU)
@register_quant_pattern(torch.nn.ReLU6)
@register_quant_pattern(torch.adaptive_avg_pool1d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool3d)
@register_quant_pattern(torch.nn.functional.hardtanh)
@register_quant_pattern(torch.nn.functional.hardtanh_)
@register_quant_pattern(torch.nn.functional.interpolate)
@register_quant_pattern(torch.nn.functional.max_pool1d)
@register_quant_pattern(torch.nn.functional.max_pool2d)
@register_quant_pattern(torch.nn.functional.max_pool3d)
@register_quant_pattern(torch.nn.functional.relu)
@register_quant_pattern(torch.nn.functional.relu6)
@register_quant_pattern(torch.avg_pool1d)
@register_quant_pattern(torch._C._nn.avg_pool2d)
@register_quant_pattern(torch._C._nn.avg_pool3d)
@register_quant_pattern(torch.clamp)
@register_quant_pattern(torch.flatten)
@register_quant_pattern(torch.mean)
@register_quant_pattern(operator.floordiv)
@register_quant_pattern('clamp')
@register_quant_pattern('mean')
@register_quant_pattern('relu')
@register_quant_pattern('relu_')
class CopyNodeQuantizeHandler(QuantizeHandler):
    """ Operators that work on both float and quantized input.
    If the input is quantized, the output Tensor shares
    the same quantization parameters with the input.
    These ops do computation on the input Tensor (e.g. average pool), so we
    insert an extra observer/fake_quant for the output of these operators.
    TODO: maybe rename this to TensorValueOpQuantizeHandler
    """
    def should_mark_output_quantized_from_input_quantized_status(
            self,
            qconfig: QConfigAny
    ) -> bool:
        return True

    def is_general_tensor_value_op(self) -> bool:
        return True

    def convert(self,
                node: Node,
                qconfig: QConfigAny,
                modules: Dict[str, torch.nn.Module],
                quantized_graph: Graph,
                node_name_to_scope: Dict[str, Tuple[str, type]],
                load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:

        # when activation dtype is torch.float, the node does not require
        # observation
        # e.g. dynamic quantization or weight_only quantization
        act_dtype = activation_dtype(qconfig)
        if act_dtype == torch.float:
            op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
            return op_out
        else:
            activation_post_process = \
                self._maybe_get_last_node_only_observer(modules)
            if activation_post_process is not None:
                # make sure the input is quantized to act_dtype
                load_arg(quantized={0: act_dtype})(node.args)
                args = list(load_arg(quantized=torch.float)(node.args))
                kwargs = load_arg(quantized=torch.float)(node.kwargs)
                op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
                return quantize_node(
                    op_out,
                    activation_post_process,
                    node, modules, quantized_graph, node_name_to_scope, is_input=False)
            else:
                op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
                return op_out
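
# The "copy node" ops above operate directly on quantized values and the output reuses
# the input's quantization parameters. A hedged standalone sketch of that behavior on a
# quantized tensor (independent of the FX convert path above):
#
#     import torch
#
#     x = torch.randn(1, 1, 4, 4)
#     qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=64, dtype=torch.quint8)
#     qy = torch.nn.functional.max_pool2d(qx, kernel_size=2)
#     assert qy.q_scale() == qx.q_scale() and qy.q_zero_point() == qx.q_zero_point()
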
class CustomModuleQuantizeHandler(QuantizeHandler):
    def convert(self,
                node: Node,
                qconfig: QConfigAny,
                modules: Dict[str, torch.nn.Module],
                quantized_graph: Graph,
                node_name_to_scope: Dict[str, Tuple[str, type]],
                load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """ Convert a float custom module to a quantized custom module
        """
        assert node.op == 'call_module'
        assert convert_custom_config_dict is not None
        custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", None)
        assert custom_module_class_mapping is not None
        observed_custom_module = modules[str(node.target)]
        if activation_is_statically_quantized(qconfig):
            activation_post_process = \
                self._maybe_get_last_node_only_observer(modules)
            assert activation_post_process is not None
            observed_custom_module.activation_post_process = activation_post_process
        quantized_custom_module_class = get_swapped_custom_module_class(
            observed_custom_module, custom_module_class_mapping, qconfig)
        quantized_custom_module = \
            quantized_custom_module_class.from_observed(observed_custom_module)
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, quantized_custom_module)
        # hardcoded the quantized input to be None (take whatever is in the environment);
        # we can extend this
        # if there is a need, e.g. get the indexes of quantized inputs from some
        # module attribute like module._QUANTIZED_INPUT_INDEXES
        return quantized_graph.node_copy(node, load_arg(quantized=None))
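
# This handler is driven by convert_custom_config_dict. A hedged sketch of the mapping
# it expects (ObservedCustomModule and QuantizedCustomModule are hypothetical user
# classes; QuantizedCustomModule is assumed to provide a `from_observed` classmethod,
# and the "static" nesting follows the convert_fx documentation of this release):
#
#     convert_custom_config_dict = {
#         "observed_to_quantized_custom_module_class": {
#             "static": {
#                 ObservedCustomModule: QuantizedCustomModule,
#             },
#         },
#     }
#     quantized = convert_fx(prepared, convert_custom_config_dict=convert_custom_config_dict)
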
@register_quant_pattern(torch.nn.Identity)
@register_quant_pattern(torch.transpose)
@register_quant_pattern(torch.repeat_interleave)
@register_quant_pattern(torch.squeeze)
@register_quant_pattern(torch.stack)
@register_quant_pattern(torch.unsqueeze)
@register_quant_pattern('contiguous')
@register_quant_pattern('detach')
@register_quant_pattern('detach_')
@register_quant_pattern('permute')
@register_quant_pattern('repeat')
@register_quant_pattern('repeat_interleave')
@register_quant_pattern('reshape')
@register_quant_pattern('resize_')
@register_quant_pattern('shape')
@register_quant_pattern('size')
@register_quant_pattern('squeeze')
@register_quant_pattern('squeeze_')
@register_quant_pattern('transpose')
@register_quant_pattern('unsqueeze')
@register_quant_pattern('unsqueeze_')
@register_quant_pattern('view')
class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler):
    """ Operators that work on both float and quantized input.
    If the input is quantized, the output Tensor shares
    the same quantization parameters with the input.
    These ops only rearrange Tensor values (for example reshape)
    or query information about the Tensor (e.g. size),
    so we do not insert an extra observer/fake_quant
    for the output of the operator.
    """
    def is_general_tensor_shape_op(self) -> bool:
        return True

    def should_mark_output_quantized_from_input_quantized_status(
            self,
            qconfig: QConfigAny
    ) -> bool:
        return True

    def convert(self,
                node: Node,
                qconfig: QConfigAny,
                modules: Dict[str, torch.nn.Module],
                quantized_graph: Graph,
                node_name_to_scope: Dict[str, Tuple[str, type]],
                load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        # when activation dtype is torch.float, the node does not require
        # observation
        # e.g. dynamic quantization or weight_only quantization
        act_dtype = activation_dtype(qconfig)
        if act_dtype == torch.float:
            op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
            return op_out
        else:
            activation_post_process = \
                self._maybe_get_last_node_only_observer(modules)
            if activation_post_process is not None:
                args = list(load_arg(quantized=torch.float)(node.args))
                kwargs = load_arg(quantized=torch.float)(node.kwargs)
                op_out = quantized_graph.node_copy(node, load_arg(quantized=torch.float))
                return quantize_node(
                    op_out,
                    activation_post_process,
                    node, modules, quantized_graph, node_name_to_scope, is_input=False)
            else:
                return quantized_graph.node_copy(node, load_arg(quantized=torch.float))
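
# Shape/view style ops above do not change Tensor values, so no new observer is
# inserted and a quantized input simply flows through with its qparams intact.
# A hedged standalone sketch (independent of the FX convert path above):
#
#     import torch
#
#     x = torch.randn(2, 8)
#     qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)
#     qy = qx.reshape(4, 4)
#     assert qy.is_quantized
#     assert qy.q_scale() == qx.q_scale() and qy.q_zero_point() == qx.q_zero_point()
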
class StandaloneModuleQuantizeHandler(QuantizeHandler):
    """ Converts an observed standalone module to a quantized standalone module
    by calling convert_fx on the observed standalone module.
    """
    def convert(self,
                node: Node,
                qconfig: QConfigAny,
                modules: Dict[str, torch.nn.Module],
                quantized_graph: Graph,
                node_name_to_scope: Dict[str, Tuple[str, type]],
                load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        assert node.op == 'call_module'
        convert = torch.ao.quantization.quantize_fx._convert_standalone_module_fx  # type: ignore[attr-defined]
        # We know that observed standalone module is a GraphModule since
        # it's produced by us
        observed_standalone_module: GraphModule = modules[str(node.target)]  # type: ignore[assignment]
        input_quantized_idxs = observed_standalone_module._standalone_module_input_quantized_idxs.tolist()  # type: ignore[operator]
        quantized_standalone_module = convert(observed_standalone_module, is_reference=is_reference)
        parent_name, name = _parent_name(node.target)
        # update the modules dict
        setattr(modules[parent_name], name, quantized_standalone_module)
        modules[str(node.target)] = quantized_standalone_module
        return quantized_graph.node_copy(node, load_arg(quantized=input_quantized_idxs))
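
# A hedged sketch of what this handler does to the parent graph (the node and module
# names below are made up for illustration). If the observed standalone submodule
# recorded _standalone_module_input_quantized_idxs == [0], the parent graph feeds arg 0
# to the converted submodule still quantized (no dequantize is inserted for it), while
# other inputs are loaded as float:
#
#     # before convert (parent graph, conceptually)
#     #   observed_x = activation_post_process_0(x)
#     #   out = self.standalone(observed_x)
#     #
#     # after convert
#     #   qx = torch.quantize_per_tensor(x, scale_0, zero_point_0, torch.quint8)
#     #   out = self.standalone(qx)   # self.standalone is now the quantized GraphModule
#     #                               # returned by _convert_standalone_module_fx
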