mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57470 Removes the earlier hack of matching patterns originally matched to BinaryOpQuantizeHandler to switch to CopyHandler. After this PR, each pattern can only be matched to one type of QuantizeHandler or to nothing. Test Plan: ``` python test/test_quantization.py TestQuantizeFx python test/test_quantization.py TestQuantizeFxOps ``` Imported from OSS Reviewed By: jerryzh168 Differential Revision: D28152909 fbshipit-source-id: afc285e770bd7eb0518c90e3ee4874c421e78bbc
1286 lines
64 KiB
Python
1286 lines
64 KiB
Python
import torch
|
|
from torch.fx.graph import (
|
|
Node,
|
|
)
|
|
import torch.nn.quantized as nnq
|
|
import torch.nn.quantized.dynamic as nnqd
|
|
from torch.quantization import (
|
|
default_affine_fixed_qparams_fake_quant,
|
|
default_symmetric_fixed_qparams_fake_quant,
|
|
)
|
|
|
|
from ..quantization_mappings import (
|
|
get_static_quant_module_class,
|
|
get_dynamic_quant_module_class,
|
|
get_quantized_operator,
|
|
)
|
|
from ..utils import (
|
|
get_swapped_custom_module_class,
|
|
activation_is_statically_quantized,
|
|
activation_is_int8_quantized,
|
|
weight_is_statically_quantized,
|
|
get_qconfig_dtypes,
|
|
activation_dtype,
|
|
)
|
|
|
|
from .pattern_utils import (
|
|
register_quant_pattern,
|
|
get_default_output_activation_post_process_map,
|
|
Pattern,
|
|
)
|
|
|
|
from .utils import (
|
|
_parent_name,
|
|
all_node_args_have_no_tensors,
|
|
quantize_node,
|
|
get_per_tensor_qparams,
|
|
get_linear_prepack_op_for_dtype,
|
|
create_qparam_nodes,
|
|
get_qconv_prepack_op,
|
|
get_qconv_op,
|
|
)
|
|
|
|
from .quantization_types import QuantizerCls
|
|
|
|
from abc import ABC, abstractmethod
|
|
import operator
|
|
import warnings
|
|
|
|
from typing import Any, Callable, Dict, Union, Optional, Tuple, List, Set
|
|
|
|
# -------------------------
|
|
# Pattern Registrations
|
|
# -------------------------
|
|
|
|
# 1. Post Training Static Quantization and Quantization Aware Training Patterns
|
|
|
|
# Base Pattern Handler
|
|
class QuantizeHandler(ABC):
    """ Base handler class for the quantizer patterns.

    A handler is constructed with the quantizer and the matched node, records
    what it needs about the pattern in ``__init__`` and later emits the
    converted node in ``convert``. The remaining methods are hooks with
    default answers that subclasses may override.
    """
    def __init__(self, quantizer: QuantizerCls, node: Node):
        """ Records pattern information in __init__, which will be used
        in convert
        """
        # this is an indicator of whether all the inputs are Node or not
        # since some op might be quantized differently depending on whether
        # all inputs are tensors or not, e.g. add/mul
        self.num_tensor_args = len(node.args)
        self.all_node_args_are_tensors = True

    def input_output_observed(self) -> bool:
        """
        Returns True if the pattern matched to this qhandler could be
        observed, and False if it should not be observed.
        """
        return True

    def should_insert_observer_for_output(
        self,
        qconfig: Any,
        model_is_training: bool,
    ) -> bool:
        """
        Returns true if an observer should be inserted for the output of
        the pattern matched to this QuantizeHandler instance during the
        prepare step.

        The base implementation ignores both arguments and only consults
        the state recorded on the handler.
        """
        # TODO(future PR): potentially clean up and deduplicate these
        # mappings.
        return self.all_node_args_are_tensors and self.input_output_observed()

    def should_mark_output_observed_from_input_observed_status(
        self,
        observed_node_names_set: Set[str],
    ) -> bool:
        """
        Returns true if the output of this pattern instance should be marked
        as observed based on the observed status of inputs to this pattern.
        The base implementation always answers False.
        """
        return False

    def should_mark_output_quantized_from_input_quantized_status(
        self,
    ) -> bool:
        """
        Returns true if after convert, the output of the matched pattern is
        quantized iff the first input is also quantized.
        The base implementation always answers False.
        """
        return False

    def get_activation_ctr(
        self,
        qconfig: Any,
        pattern: Pattern,
    ) -> Optional[Callable]:
        """
        Returns the constructor for the activation observer which should be
        used for the pattern matched to this handler. Some handlers override
        this to a different value than what is specified in the qconfig.
        """
        return qconfig.activation

    @abstractmethod
    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
        """ Convert the given node to a quantized node and insert
        it to the quantized graph
        """
        return NotImplemented
|
|
|
|
|
|
# Binary op configs

# Supported combinations are:
# quant_type | activation (compute_type) | weight
#  static       quint8                      qint8

# Every entry below is a tuple (activation_dtype, weight_dtype, compute_dtype).
# These are the supported types for common binary ops like add/mul etc.
binary_op_int8_dtypes = [
    (torch.quint8, torch.qint8, None),
]
binary_op_float16_dtypes = [
    (torch.float16, torch.float16, None)
]
# int8 entry first, then the float16 entry
binary_op_all_dtypes = binary_op_int8_dtypes + binary_op_float16_dtypes

# Dtype combinations each binary op accepts when is_reference is False.
# add/mul accept int8 and float16; the remaining ops accept float16 only.
binary_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {
    **dict.fromkeys(
        (operator.add, torch.add, operator.mul, torch.mul),
        binary_op_all_dtypes),
    **dict.fromkeys(
        (torch.bmm, torch.sub, operator.sub, torch.div, operator.truediv, torch.sum),
        binary_op_float16_dtypes),
}

# Dtype combinations each binary op accepts when is_reference is True.
binary_reference_op_supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = dict(
    [(torch.bmm, binary_op_int8_dtypes)]
)
|
|
|
|
|
|
@register_quant_pattern(operator.add)
@register_quant_pattern(operator.sub)
@register_quant_pattern(operator.mul)
@register_quant_pattern(operator.truediv)
@register_quant_pattern(torch.add)
@register_quant_pattern(torch.sub)
@register_quant_pattern(torch.mul)
@register_quant_pattern(torch.div)
@register_quant_pattern(torch.sum)
@register_quant_pattern(torch.bmm)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.ReLU, torch.add))
@register_quant_pattern((torch.nn.ReLU, torch.mul))
@register_quant_pattern((torch.nn.functional.relu, operator.add))
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
@register_quant_pattern((torch.nn.functional.relu, torch.add))
@register_quant_pattern((torch.nn.functional.relu, torch.mul))
class BinaryOpQuantizeHandler(QuantizeHandler):
    """ Handler for the binary op patterns listed in the decorators above:
    add/sub/mul/div/sum/bmm, and add/mul followed by a ReLU.
    """
    def __init__(self, quantizer: QuantizerCls, node: Node):
        """ Records the (optional) relu node, the binary op node, how many
        of its positional args carry tensors, and the quantized op to use.
        """
        super().__init__(quantizer, node)
        self.relu_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
                (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
            # the pattern root is the relu; the binary op is its first input
            self.relu_node = node
            node = node.args[0]  # type: ignore[assignment]
        self.binary_op_node = node
        self.binary_op = node.target

        # determine how many of the positional args are Tensors (versus scalars)
        # this distinguishes things like "x + y" from "x + 2" or "2 + x"
        self.num_tensor_args = 0
        cache_for_no_tensor_check: Dict[Node, bool] = dict()
        for arg in self.binary_op_node.args:
            if isinstance(arg, Node) and (
                    not all_node_args_have_no_tensors(arg, quantizer.modules, cache_for_no_tensor_check)):
                self.num_tensor_args += 1
        self.all_node_args_are_tensors = \
            (self.num_tensor_args == len(self.binary_op_node.args))

        qbin_op_mapping: Dict[Union[Callable, str], Callable] = {
            operator.add: torch.ops.quantized.add,
            torch.add: torch.ops.quantized.add,
            operator.mul: torch.ops.quantized.mul,
            torch.mul: torch.ops.quantized.mul,
        }
        qbin_relu_op_mapping: Dict[Union[Callable, str], Callable] = {
            operator.add: torch.ops.quantized.add_relu,
            torch.add: torch.ops.quantized.add_relu,
            operator.mul: torch.ops.quantized.mul_relu,
            torch.mul: torch.ops.quantized.mul_relu,
        }
        # corresponding quantized op; stays None for ops that have no entry
        # in the mappings above (sub/div/sum/bmm)
        self.quantized_binary_op: Optional[Callable] = None
        if self.binary_op in qbin_op_mapping:
            self.quantized_binary_op = qbin_relu_op_mapping[self.binary_op] \
                if self.relu_node is not None \
                else qbin_op_mapping[self.binary_op]

    def should_insert_observer_for_output(
        self,
        qconfig: Any,
        model_is_training: bool,
    ) -> bool:
        """
        Returns true if an observer should be inserted for the output of
        the pattern matched to this QuantizeHandler instance during the
        prepare step.

        With exactly one tensor argument the answer depends only on whether
        the activation dtype of ``qconfig`` is float16.
        """
        if self.num_tensor_args == 1:
            return activation_dtype(qconfig) == torch.float16
        return bool(self.all_node_args_are_tensors and self.input_output_observed())

    def should_mark_output_observed_from_input_observed_status(
        self,
        observed_node_names_set: Set[str],
    ) -> bool:
        """
        Returns true if exactly one argument is a tensor and one of the
        first two positional arguments is a Node whose name is in
        ``observed_node_names_set``.
        """
        if self.num_tensor_args != 1:
            # If either none or both inputs are tensors, this code
            # path will not be hit.
            return False

        # If only one of the inputs is a tensor, the output is
        # observed if the tensor input is observed
        def input_is_observed(arg):
            return (isinstance(arg, Node) and
                    arg.name in observed_node_names_set)

        # This is checking if one of the argument of add/mul
        # is an observed node
        # If both of the inputs are number,
        # we will not consider the output to be observed
        # Slicing instead of indexing args[1] keeps single-argument calls
        # (e.g. torch.sum(x)) from raising IndexError.
        return any(input_is_observed(arg) for arg in self.binary_op_node.args[:2])

    def input_output_observed(self):
        """ Returns False when no argument is a tensor. """
        # for x + y where x and y are scalars, we do not observe anything
        return self.num_tensor_args > 0

    def _copy_as_float(self, quantizer: QuantizerCls, node: Node, load_arg: Callable) -> Node:
        """ Copy the matched pattern into the quantized graph with every
        argument loaded through ``load_arg(quantized=False)``. When a relu
        was matched, the binary op is copied first and a functional relu is
        created on top of it.
        """
        if self.relu_node:
            op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
            relu_args = [op_out]
            relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
            relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
            return quantizer.quantized_graph.create_node(
                "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)
        return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
        """ Emit the converted pattern into ``quantizer.quantized_graph``.

        Depending on ``is_reference`` and the dtypes of the node's qconfig
        this produces a quantized op call, a float op followed by
        ``quantize_node``, or an unquantized copy (with a warning when the
        dtype combination is not supported).
        """
        if self.num_tensor_args == 0:
            # example: x + y, when x and y are scalars
            return quantizer.quantized_graph.node_copy(
                node, load_arg(quantized=None))

        qconfig = quantizer.qconfig_map[node.name]
        dtypes = get_qconfig_dtypes(qconfig)

        if is_reference and self.binary_op in binary_reference_op_supported_dtypes and \
                dtypes in binary_reference_op_supported_dtypes[self.binary_op]:
            if dtypes in binary_op_int8_dtypes:
                # The results of these calls are not used here. They are kept
                # in their original order because load_arg is opaque from
                # this file and may update quantizer state.
                # NOTE(review): confirm against the quantizer's load_arg.
                load_arg(quantized=[0, 1])(node.args)
                load_arg(quantized=False)(node.args)
                load_arg(quantized=False)(node.kwargs)
                op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
                cur_idx = quantizer.activation_post_process_indexes[node.name]
                activation_post_process = \
                    quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
                quantizer.activation_post_process_indexes[node.name] += 1
                return quantize_node(
                    quantizer, op_out, activation_post_process,
                    node, is_input=False)
            else:
                warnings.warn(
                    "No implementation found for dtype combination: {} "
                    "for op {} with is_reference={} despite it being listed as supported, "
                    "this should not happen".format(dtypes, self.binary_op, is_reference))
                return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
        elif not is_reference and self.binary_op in binary_op_supported_dtypes and \
                dtypes in binary_op_supported_dtypes[self.binary_op]:
            if dtypes in [(torch.quint8, torch.qint8, None)]:
                assert self.quantized_binary_op is not None
                if self.num_tensor_args == 1:
                    # add/mul scalar: only the tensor operand is loaded quantized
                    first_arg = self.binary_op_node.args[0]
                    cache_for_no_tensor_check: Dict[Node, bool] = dict()
                    if isinstance(first_arg, Node) and (
                            not all_node_args_have_no_tensors(
                                first_arg, quantizer.modules, cache_for_no_tensor_check)):
                        quantized_index = 0
                    else:
                        quantized_index = 1

                    return quantizer.quantized_graph.create_node(
                        'call_function', self.quantized_binary_op,
                        load_arg(quantized=[quantized_index])(self.binary_op_node.args), self.binary_op_node.kwargs)
                else:
                    cur_idx = quantizer.activation_post_process_indexes[node.name]
                    activation_post_process = \
                        quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
                    quantizer.activation_post_process_indexes[node.name] += 1
                    scale, zero_point = activation_post_process.calculate_qparams()
                    scale = float(scale)
                    zero_point = int(zero_point)
                    scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)

                    # self.quantized_binary_op was already chosen in __init__
                    # with the relu taken into account
                    kwargs = {**self.binary_op_node.kwargs}
                    add_args = (*load_arg(quantized=True)(self.binary_op_node.args), scale_arg, zero_point_arg)
                    return quantizer.quantized_graph.create_node(
                        'call_function', self.quantized_binary_op, add_args, kwargs)
            else:
                assert dtypes == (torch.float16, torch.float16, None)
                return self._copy_as_float(quantizer, node, load_arg)
        else:
            # leave the op unquantized if the dtype,reference combination is not supported
            warnings.warn(
                "dtype combination: {} is not "
                "supported by {} for is_reference={}. "
                "Supported non-reference dtype combinations are: {} "
                "Supported reference dtype combinations are: {}"
                "".format(dtypes,
                          self.binary_op,
                          is_reference,
                          binary_op_supported_dtypes.get(self.binary_op, []),
                          binary_reference_op_supported_dtypes.get(self.binary_op, [])))
            return self._copy_as_float(quantizer, node, load_arg)
|
|
|
|
|
|
@register_quant_pattern(torch.cat)
class CatQuantizeHandler(QuantizeHandler):
    """ Handler registered for the ``torch.cat`` pattern. """
    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """ Copy ``node`` into the quantized graph with its arguments loaded
        via ``load_arg(quantized=True)``, after advancing the node's
        activation post process index.

        Returns ``NotImplemented`` instead when
        ``self.all_node_args_are_tensors`` is false.
        """
        if self.all_node_args_are_tensors:
            quantizer.activation_post_process_indexes[node.name] += 1
            return quantizer.quantized_graph.node_copy(node, load_arg(quantized=True))
        return NotImplemented
|
|
|
|
# handle conv, maybe followed by relu
|
|
# NB: matching order is reversed, that is we match from the bottom of this list to the beginning
|
|
@register_quant_pattern(torch.nn.Conv1d)
@register_quant_pattern(torch.nn.Conv2d)
@register_quant_pattern(torch.nn.Conv3d)
@register_quant_pattern(torch.nn.functional.conv1d)
@register_quant_pattern(torch.nn.functional.conv2d)
@register_quant_pattern(torch.nn.functional.conv3d)
# TODO: add qat.Conv1d
@register_quant_pattern(torch.nn.qat.Conv2d)
@register_quant_pattern(torch.nn.qat.Conv3d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU1d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU3d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn1d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn3d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU1d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU3d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU3d)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv1d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv3d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv1d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv3d))
# just for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv3d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
class ConvReluQuantizeHandler(QuantizeHandler):
    """ Handler for conv modules / functional convs, optionally followed by
    a ReLU, as listed in the decorators above.
    """
    def __init__(self, quantizer: QuantizerCls, node: Node):
        """ Records the (optional) relu node, the conv node and the conv
        module (for ``call_module``) or conv function (for ``call_function``).
        """
        super().__init__(quantizer, node)
        self.relu_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
                (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
            # the pattern root is the relu; the conv is its first input
            self.relu_node = node
            node = node.args[0]  # type: ignore[assignment]
        self.conv_node = node
        if node.op == "call_module":
            self.conv = quantizer.modules[self.conv_node.target]
        elif node.op == "call_function":
            self.conv = node.target

    def _apply_float_relu(self, quantizer: QuantizerCls, load_arg: Callable, op_out: Node) -> Node:
        """ Create a functional relu node on top of ``op_out``, carrying over
        the remaining args/kwargs of the matched relu node loaded with
        ``load_arg(quantized=False)``.
        """
        relu_args = [op_out]
        relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
        relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
        return quantizer.quantized_graph.create_node(
            "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
        """ Emit the converted conv (+relu) pattern into
        ``quantizer.quantized_graph``.

        Unsupported dtype combinations produce a warning and an unquantized
        copy. Conv modules are swapped for their static quantized class;
        functional convs become either a float conv wrapped with
        ``quantize_node`` (``is_reference=True``) or a prepack op plus a
        quantized conv op.
        """
        # Supported combinations are:
        # quant_type | activation (compute_type) | weight
        #  static       quint8                      qint8

        # tuple (activation_dtype, weight_dtype, compute_dtype)
        supported_dtypes = [
            (torch.quint8, torch.qint8, None),
        ]

        # TODO: is_reference option for conv module
        qconfig = quantizer.qconfig_map[node.name]
        dtypes = get_qconfig_dtypes(qconfig)
        # leave the op unquantized if the dtype combination is not supported
        if dtypes not in supported_dtypes:
            warnings.warn(
                "dtype combination: {} is not "
                "supported by Conv "
                "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
            if self.relu_node:
                conv_out = quantizer.quantized_graph.node_copy(self.conv_node, load_arg(quantized=False))
                return self._apply_float_relu(quantizer, load_arg, conv_out)
            else:
                return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))

        activation_int8_quantized = activation_is_int8_quantized(qconfig)

        if self.conv_node.op == 'call_module':
            # note that relu should already be fused into conv module in the fusion step
            assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
                'please make sure to run fusion before prepare'
            if convert_custom_config_dict is None:
                convert_custom_config_dict = {}
            additional_static_quant_mapping = convert_custom_config_dict.get("static", {})
            # 1. attach activation post process to module
            cur_idx = quantizer.activation_post_process_indexes[node.name]
            self.conv.activation_post_process = \
                quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
            quantizer.activation_post_process_indexes[node.name] += 1
            # 2. select quantized class
            qconv_cls = get_static_quant_module_class(
                type(self.conv), additional_static_quant_mapping)
            quantized = qconv_cls.from_float(self.conv)
            # 3. replace the float module on its parent with the quantized one
            parent_name, name = _parent_name(self.conv_node.target)
            setattr(quantizer.modules[parent_name], name, quantized)
            return quantizer.quantized_graph.create_node(
                'call_module',
                self.conv_node.target,
                (load_arg(quantized=True)(self.conv_node.args[0]),),
                {})
        else:  # call_function
            assert self.conv_node.op == "call_function"
            if is_reference:
                # The result of the first call is not used here. It is kept
                # ahead of the non-quantized load because load_arg is opaque
                # from this file and may update quantizer state.
                # NOTE(review): confirm against the quantizer's load_arg.
                load_arg(quantized=[0, 1])(self.conv_node.args)
                args = load_arg(quantized=False)(self.conv_node.args)
                kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
                op_out = quantizer.quantized_graph.create_node(
                    "call_function", self.conv, args, kwargs)
                if self.relu_node:
                    op_out = self._apply_float_relu(quantizer, load_arg, op_out)

                if activation_int8_quantized:
                    # the observer is looked up under the name of the pattern
                    # root: the relu node when present, else the conv node
                    act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
                    act_post_process_node = self.relu_node if self.relu_node else self.conv_node
                    cur_idx = quantizer.activation_post_process_indexes[act_post_process_name]
                    activation_post_process = \
                        quantizer.modules[quantizer.activation_post_process_map[act_post_process_name][cur_idx]]
                    quantizer.activation_post_process_indexes[act_post_process_name] += 1
                    return quantize_node(
                        quantizer, op_out, activation_post_process,
                        act_post_process_node, is_input=False)
                else:
                    # output for dynamically quantized conv op is not quantized
                    return op_out
            else:
                assert len(self.conv_node.args) >= 7, \
                    "only conv2d calls with all arguments specified is supported right now in is_reference=False option"
                # Result unused; the call is kept because load_arg is opaque
                # from this file and may update quantizer state.
                # NOTE(review): confirm against the quantizer's load_arg.
                load_arg(quantized=[0, 1])(self.conv_node.args)
                # pack weight
                weight = load_arg(quantized=True)(self.conv_node.args[1])
                other_args = load_arg(quantized=False)(self.conv_node.args[2:])
                bias, stride, padding, dilation, groups = other_args
                if self.conv == torch.nn.functional.conv1d:
                    # F.conv1d can take `int` as well as `list[int]` for stride,
                    # padding, dilation, but the prepack op cannot. Convert
                    # these to lists if needed.
                    stride = [stride] if isinstance(stride, int) else stride
                    padding = [padding] if isinstance(padding, int) else padding
                    dilation = [dilation] if isinstance(dilation, int) else dilation
                prepack_args = (weight, bias, stride, padding, dilation, groups)
                prepack_op = get_qconv_prepack_op(self.conv)
                packed_weight = quantizer.quantized_graph.create_node(
                    "call_function", prepack_op, prepack_args, {})
                assert activation_int8_quantized, \
                    "currently only static quantization is supported for conv"
                # construct conv input
                if activation_int8_quantized:
                    qconv_op = get_qconv_op(self.conv, self.relu_node is not None)
                    conv_input = load_arg(quantized=True)(self.conv_node.args[0])
                    act_post_process_name = self.relu_node.name if self.relu_node else self.conv_node.name
                    cur_idx = quantizer.activation_post_process_indexes[act_post_process_name]
                    activation_post_process = \
                        quantizer.modules[quantizer.activation_post_process_map[act_post_process_name][cur_idx]]
                    quantizer.activation_post_process_indexes[act_post_process_name] += 1
                    scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
                    scale_node, zero_point_node = create_qparam_nodes(quantizer, self.conv_node.name, scale, zero_point)
                    qconv_args = (conv_input, packed_weight, scale_node, zero_point_node)
                    kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
                    op = quantizer.quantized_graph.create_node(
                        'call_function', qconv_op, qconv_args, kwargs)
                    # Store the name of the fused op to get the path of node after fusion as well.
                    # TODO: may need to change the key to Node regenerate the map in each transformation,
                    # since we might not be able to rely on the name
                    quantizer.node_name_to_scope[op.name] = quantizer.node_name_to_scope[self.conv_node.name]
                    return op
                else:
                    # conv2d_dynamic branch; only reachable when the assert
                    # above is stripped (python -O)
                    raise Exception("Only static quant is supported for conv")
|
|
|
|
|
|
# handle linear, maybe followed by relu
|
|
@register_quant_pattern(torch.nn.Linear)
@register_quant_pattern(torch.nn.functional.linear)
@register_quant_pattern(torch.nn.qat.Linear)
@register_quant_pattern(torch.nn.intrinsic.LinearReLU)
@register_quant_pattern(torch.nn.intrinsic.qat.LinearReLU)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.linear))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.linear))
# for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
class LinearReLUQuantizeHandler(QuantizeHandler):
    """ Handler for linear modules / functional linear, optionally followed
    by a ReLU, as listed in the decorators above.
    """
    def __init__(self, quantizer: QuantizerCls, node: Node):
        """ Records the (optional) relu node, the linear node and, for
        ``call_module`` nodes, the linear module itself.
        """
        super().__init__(quantizer, node)
        self.relu_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
                (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
            # the pattern root is the relu; the linear is its first input
            self.relu_node = node
            node = node.args[0]  # type: ignore[assignment]
        self.linear_node = node
        if node.op == 'call_module':
            self.linear = quantizer.modules[self.linear_node.target]

    def _apply_float_relu(self, quantizer: QuantizerCls, load_arg: Callable, op_out: Node) -> Node:
        """ Create a functional relu node on top of ``op_out``, carrying over
        the remaining args/kwargs of the matched relu node loaded with
        ``load_arg(quantized=False)``.
        """
        relu_args = [op_out]
        relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
        relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
        return quantizer.quantized_graph.create_node(
            "call_function", torch.nn.functional.relu, tuple(relu_args), relu_kwargs)

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> Node:
        """ Emit the converted linear (+relu) pattern into
        ``quantizer.quantized_graph``.

        Unsupported dtype combinations produce a warning and an unquantized
        copy. Linear modules are swapped for a quantized module class;
        functional linear becomes either a float linear optionally wrapped
        with ``quantize_node`` (``is_reference=True``), a prepack op plus a
        quantized/dynamic linear op, or a plain copy for static float16.
        """
        # Supported combinations are:
        # quant_type  | activation (compute_type) | weight
        #  static       quint8                      qint8
        #  dynamic      float32 (quint8)            qint8
        #  weight_only  float32                     float16
        # tuple (activation_dtype, weight_dtype, compute_dtype)
        supported_dtypes = [
            (torch.quint8, torch.qint8, None),
            (torch.float32, torch.qint8, torch.quint8),
            (torch.float32, torch.float16, None),
            # static float16 quantization
            (torch.float16, torch.float16, None),
        ]
        qconfig = quantizer.qconfig_map[node.name]
        dtypes = get_qconfig_dtypes(qconfig)
        # leave the op unquantized if the dtype combination is not supported
        if dtypes not in supported_dtypes:
            warnings.warn(
                "dtype combination: {} is not "
                "supported by Linear "
                "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
            if self.relu_node:
                op_out = quantizer.quantized_graph.node_copy(self.linear_node, load_arg(quantized=False))
                return self._apply_float_relu(quantizer, load_arg, op_out)
            else:
                return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

        activation_int8_quantized = activation_is_int8_quantized(qconfig)
        weight_dtype = dtypes[1]
        # TODO: reference_model option for linear module
        if self.linear_node.op == 'call_module':
            # note that relu should already be fused into linear module in the fusion step
            assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \
                'please make sure to run fusion before prepare'
            # 1. attach output activation post process to linear module
            if node.name in quantizer.activation_post_process_map:
                # this is the static quantization case
                cur_idx = quantizer.activation_post_process_indexes[node.name]
                output_activation_post_process = \
                    quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
                quantizer.activation_post_process_indexes[node.name] += 1
            else:
                output_activation_post_process = None

            if output_activation_post_process:
                self.linear.activation_post_process = output_activation_post_process

            # 2. select corresponding quantized linear class for the float linear class
            if type(self.linear) in [torch.nn.Linear, torch.nn.qat.Linear]:
                qlinear = nnq.Linear if activation_int8_quantized else nnqd.Linear
            elif type(self.linear) in [torch.nn.intrinsic.LinearReLU, torch.nn.intrinsic.qat.LinearReLU]:
                assert activation_int8_quantized, \
                    'Only int8 static quantization is supported for LinearReLU'
                qlinear = torch.nn.intrinsic.quantized.LinearReLU
            else:
                raise Exception("unhandled linear type:", type(self.linear))
            quantized = qlinear.from_float(self.linear)
            # 3. replace the float module on its parent with the quantized one
            parent_name, name = _parent_name(self.linear_node.target)
            setattr(quantizer.modules[parent_name], name, quantized)
            # activation needs to be quantized for static quantization
            return quantizer.quantized_graph.create_node(
                'call_module',
                self.linear_node.target,
                (load_arg(quantized=activation_int8_quantized)(self.linear_node.args[0]),), {})
        else:  # call_function
            assert self.linear_node.op == 'call_function'
            if is_reference:
                quantized_input_idxs = []
                if activation_int8_quantized:
                    quantized_input_idxs.append(0)
                if weight_is_statically_quantized(qconfig):
                    quantized_input_idxs.append(1)
                # The result of the first call is not used here. It is kept
                # ahead of the non-quantized load because load_arg is opaque
                # from this file and may update quantizer state.
                # NOTE(review): confirm against the quantizer's load_arg.
                load_arg(quantized=quantized_input_idxs)(self.linear_node.args)
                args = load_arg(quantized=False)(self.linear_node.args)
                kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
                op_out = quantizer.quantized_graph.create_node(
                    "call_function", torch.nn.functional.linear, args, kwargs)
                if self.relu_node:
                    op_out = self._apply_float_relu(quantizer, load_arg, op_out)

                if activation_int8_quantized:
                    # quantize output for statically quantized linear op;
                    # the observer is looked up under the name of the pattern
                    # root: the relu node when present, else the linear node
                    act_post_process_name = self.relu_node.name if self.relu_node else self.linear_node.name
                    act_post_process_node = self.relu_node if self.relu_node else self.linear_node
                    cur_idx = quantizer.activation_post_process_indexes[act_post_process_name]
                    activation_post_process = \
                        quantizer.modules[quantizer.activation_post_process_map[act_post_process_name][cur_idx]]
                    quantizer.activation_post_process_indexes[act_post_process_name] += 1
                    return quantize_node(
                        quantizer,
                        op_out,
                        activation_post_process,
                        act_post_process_node,
                        is_input=False)
                else:
                    # output for dynamically quantized linear op is not quantized
                    return op_out
            else:  # non-reference option
                # prepacking weights for static int8 quant and dynamic quant
                if dtypes != (torch.float16, torch.float16, None):
                    # linear args
                    # (x, weight, bias, ...)
                    weight_quantized = weight_is_statically_quantized(qconfig)
                    linear_weight = load_arg(quantized=weight_quantized)(self.linear_node.args[1])

                    # get other arguments
                    kwargs = {**load_arg(quantized=False)(self.linear_node.kwargs)}
                    # Load all args after the weight (bias included). The
                    # result is unused; the call is kept because load_arg is
                    # opaque from this file and may update quantizer state.
                    # NOTE(review): confirm against the quantizer's load_arg.
                    load_arg(quantized=False)(self.linear_node.args[2:])
                    # pack weight
                    if len(self.linear_node.args) > 2:
                        bias = load_arg(quantized=False)(self.linear_node.args[2])
                    else:
                        assert 'bias' in kwargs, \
                            'expect bias provided as a keyword argument when it is not a positional argument'
                        # bias goes to the prepack op, not to the linear op
                        bias = kwargs.pop('bias')
                    prepack_args = (linear_weight, bias)
                    prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
                    packed_weight = quantizer.quantized_graph.create_node(
                        'call_function', prepack_op, prepack_args, {})
                # construct linear input
                if activation_int8_quantized:
                    qlinear_op = torch.ops.quantized.linear_relu if self.relu_node else torch.ops.quantized.linear
                    linear_input = load_arg(quantized=True)(self.linear_node.args[0])
                    act_post_process_name = self.relu_node.name if self.relu_node else self.linear_node.name
                    cur_idx = quantizer.activation_post_process_indexes[act_post_process_name]
                    activation_post_process = \
                        quantizer.modules[quantizer.activation_post_process_map[act_post_process_name][cur_idx]]
                    quantizer.activation_post_process_indexes[act_post_process_name] += 1
                    scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)

                    scale_node, zero_point_node = create_qparam_nodes(quantizer, self.linear_node.name, scale, zero_point)

                    qlinear_args = (linear_input, packed_weight, scale_node, zero_point_node)
                    op = quantizer.quantized_graph.create_node(
                        "call_function", qlinear_op, qlinear_args, kwargs)
                    # Store the name of the fused op to get the path of node after fusion as well.
                    # TODO: may need to change the key to Node regenerate the map in each transformation,
                    # since we might not be able to rely on the name
                    quantizer.node_name_to_scope[op.name] = quantizer.node_name_to_scope[self.linear_node.name]
                    return op
                elif dtypes in [(torch.float32, torch.qint8, torch.quint8),
                                (torch.float32, torch.float16, None)]:
                    # choose linear dynamic or linear dynamic fp16 op based on weight dtype
                    qlinear_op = torch.ops.quantized.linear_dynamic \
                        if weight_dtype == torch.qint8 \
                        else torch.ops.quantized.linear_dynamic_fp16
                    linear_input = load_arg(quantized=False)(self.linear_node.args[0])
                    qlinear_args = (linear_input, packed_weight)  # type: ignore[assignment]
                    op_out = quantizer.quantized_graph.create_node(
                        "call_function", qlinear_op, qlinear_args, kwargs)
                    # Store the name of the dynamic op to get the path of node after replacement as well.
                    # TODO: may need to change the key to Node regenerate the map in each transformation,
                    # since we might not be able to rely on the name
                    quantizer.node_name_to_scope[op_out.name] = quantizer.node_name_to_scope[self.linear_node.name]
                    if self.relu_node:
                        op_out = quantizer.quantized_graph.create_node("call_function", torch.nn.functional.relu, (op_out,), {})
                    return op_out
                else:
                    assert dtypes == (torch.float16, torch.float16, None)
                    if self.relu_node:
                        op_out = quantizer.quantized_graph.node_copy(self.linear_node, load_arg(quantized=False))
                        return self._apply_float_relu(quantizer, load_arg, op_out)
                    else:
                        return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
|
|
@register_quant_pattern(torch.nn.BatchNorm2d)
@register_quant_pattern(torch.nn.BatchNorm3d)
@register_quant_pattern(torch.nn.intrinsic.BNReLU2d)
@register_quant_pattern(torch.nn.intrinsic.BNReLU3d)
class BatchNormQuantizeHandler(QuantizeHandler):
    """ Swaps a float batch norm module (optionally fused with relu) for its
    statically quantized counterpart.
    """
    def __init__(self, quantizer: QuantizerCls, node: Node):
        super().__init__(quantizer, node)
        assert node.op == 'call_module'
        # keep both the graph node and the float module it points at
        self.bn_node = node
        self.bn = quantizer.modules[self.bn_node.target]

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """ Replace the float module with the class returned by
        `get_static_quant_module_class` and emit a `call_module` node for it.

        The first positional input is loaded quantized; kwargs are loaded
        non-quantized.
        """
        if convert_custom_config_dict is None:
            convert_custom_config_dict = {}
        extra_static_mapping = convert_custom_config_dict.get("static", {})

        # 1. attach the next unused activation post process to the float module
        observer_idx = quantizer.activation_post_process_indexes[node.name]
        observer_name = quantizer.activation_post_process_map[node.name][observer_idx]
        self.bn.activation_post_process = quantizer.modules[observer_name]
        quantizer.activation_post_process_indexes[node.name] += 1

        # 2. build the quantized module and install it on the parent module
        quantized_cls = get_static_quant_module_class(type(self.bn), extra_static_mapping)
        quantized_bn = quantized_cls.from_float(self.bn)
        parent_name, attr_name = _parent_name(self.bn_node.target)
        setattr(quantizer.modules[parent_name], attr_name, quantized_bn)

        # 3. emit the call to the (now quantized) module
        target = self.bn_node.target
        new_args = load_arg(quantized=[0])(self.bn_node.args)
        new_kwargs = load_arg(quantized=False)(self.bn_node.kwargs)
        return quantizer.quantized_graph.create_node(
            'call_module', target, new_args, new_kwargs)
|
|
|
|
@register_quant_pattern(torch.nn.Embedding)
@register_quant_pattern(torch.nn.EmbeddingBag)
class EmbeddingQuantizeHandler(QuantizeHandler):
    """ Weight-only quantization of Embedding / EmbeddingBag modules.
    """
    def __init__(self, quantizer: QuantizerCls, node: Node):
        super().__init__(quantizer, node)

    def input_output_observed(self) -> bool:
        return False

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """ Swap the float embedding module for its quantized class, or copy
        the node unchanged (with a warning) when the qconfig dtypes are not
        supported.
        """
        assert node.op == 'call_module'

        # Each entry is (activation_dtype, weight_dtype, compute_dtype).
        # Supported combinations are:
        # quant_type  | activation | weight   | activation_compute_type
        # weight_only | float32    | quint8   | None
        # weight_only | float32    | quint4x2 | None
        supported_dtypes = [
            (torch.float32, torch.quint8, None),
            (torch.float32, torch.quint4x2, None),
        ]

        dtypes = get_qconfig_dtypes(quantizer.qconfig_map[node.name])
        if dtypes not in supported_dtypes:
            # unsupported combination: leave the op unquantized
            message_template = (
                "dtype combination: {} is not "
                "supported by Embedding/EmbeddingBag, "
                "supported dtype combinations are: {}")
            warnings.warn(message_template.format(dtypes, supported_dtypes))
            return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

        float_emb = quantizer.modules[node.target]
        quantized_emb = get_static_quant_module_class(type(float_emb)).from_float(float_emb)
        parent_name, attr_name = _parent_name(node.target)
        setattr(quantizer.modules[parent_name], attr_name, quantized_emb)

        new_args = load_arg(quantized=False)(node.args)
        new_kwargs = load_arg(quantized=False)(node.kwargs)
        return quantizer.quantized_graph.create_node(
            'call_module', node.target, new_args, new_kwargs)
|
|
|
|
# TODO (maybe): merge with embedding quantize handler
|
|
@register_quant_pattern(torch.nn.GRUCell)
@register_quant_pattern(torch.nn.LSTMCell)
@register_quant_pattern(torch.nn.RNNCell)
@register_quant_pattern(torch.nn.LSTM)
class RNNDynamicQuantizeHandler(QuantizeHandler):
    """ Dynamic quantization of the registered recurrent modules
    (GRUCell, LSTMCell, RNNCell, LSTM).
    """
    def __init__(self, quantizer: QuantizerCls, node: Node):
        super().__init__(quantizer, node)

    def input_output_observed(self) -> bool:
        return False

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """ Swap the float module for the class returned by
        `get_dynamic_quant_module_class` and emit a `call_module` node whose
        args and kwargs are loaded non-quantized.

        If the qconfig dtype combination is not supported, a warning is
        issued and the node is copied unchanged.
        """
        # Supported combinations are:
        # quant_type  | activation | weight | activation_compute_type
        # dynamic   |  float32   | qint8  | quint8
        # dynamic   |  float32   | float16 | None
        # tuple (activation_dtype, weight_dtype, compute_dtype)
        supported_dtypes = [
            (torch.float32, torch.qint8, torch.quint8),
            (torch.float32, torch.float16, None),
        ]
        assert node.op == 'call_module'
        qconfig = quantizer.qconfig_map[node.name]
        dtypes = get_qconfig_dtypes(qconfig)
        module = quantizer.modules[node.target]
        # leave the op unquantized if the dtype combination is not supported
        if dtypes not in supported_dtypes:
            # name the actual module type in the warning (the message used to
            # say "Embedding/EmbeddingBag", copied from the embedding handler)
            warnings.warn(
                "dtype combination: {} is not "
                "supported by {}, "
                "supported dtype combinations are: {}".format(
                    dtypes, type(module).__name__, supported_dtypes))
            return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

        qmodule_cls = get_dynamic_quant_module_class(type(module))
        qmodule = qmodule_cls.from_float(module)
        parent_name, name = _parent_name(node.target)
        setattr(quantizer.modules[parent_name], name, qmodule)
        return quantizer.quantized_graph.create_node(
            'call_module',
            node.target,
            load_arg(quantized=False)(node.args),
            load_arg(quantized=False)(node.kwargs))
|
|
|
|
# Maps a quantized operator to the keyword-argument names that are removed
# from the float call's kwargs before the quantized call_function node is
# created (see DefaultNodeQuantizeHandler.convert, which pops these names).
# NOTE(review): presumably these are arguments the quantized kernels do not
# accept -- confirm against the quantized op signatures.
ARGS_TO_SKIP = {
    torch._ops.ops.quantized.hardswish: ['inplace'],
    torch._ops.ops.quantized.instance_norm:
    ['running_mean', 'running_var', 'use_input_stats', 'momentum'],
}
|
|
@register_quant_pattern(torch.nn.ConvTranspose1d)
@register_quant_pattern(torch.nn.ConvTranspose2d)
@register_quant_pattern(torch.nn.ELU)
@register_quant_pattern(torch.nn.LeakyReLU)
@register_quant_pattern(torch.nn.Hardswish)
@register_quant_pattern(torch.nn.InstanceNorm1d)
@register_quant_pattern(torch.nn.InstanceNorm2d)
@register_quant_pattern(torch.nn.InstanceNorm3d)
@register_quant_pattern(torch.nn.LayerNorm)
@register_quant_pattern(torch.nn.SiLU)
# we currently only support reference patterns for these ops so they have been removed
# until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig
# @register_quant_pattern(torch.nn.GELU)
# @register_quant_pattern(torch.nn.Softmax)
@register_quant_pattern(torch.nn.functional.hardswish)
@register_quant_pattern(torch.nn.functional.instance_norm)
@register_quant_pattern(torch.nn.functional.layer_norm)
@register_quant_pattern(torch.nn.functional.leaky_relu)
@register_quant_pattern(torch.nn.functional.silu)
# we currently only support reference patterns for these ops so they have been removed
# until they receive a proper fp16 kernel. To use the reference pattern, use a custom qconfig
# @register_quant_pattern(torch.nn.functional.gelu)
# @register_quant_pattern(torch.nn.functional.softmax)
class DefaultNodeQuantizeHandler(QuantizeHandler):
    ''' Common quantized op, first input and first output will be quantized
    '''

    def __init__(self, quantizer: QuantizerCls, node: Node):
        super().__init__(quantizer, node)
        # self.op is the key used by convert() to look up the supported dtypes:
        # the callable/method name for function and method nodes, the module
        # type for module nodes.
        if node.op == "call_function" or node.op == "call_method":
            self.op = node.target
        elif node.op == "call_module":
            self.op = type(quantizer.modules[node.target])

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        ''' Convert `node` to its quantized form based on the qconfig dtypes.

        Returns NotImplemented when not all node args are tensors. When the
        dtype combination is not supported for `self.op`, a warning is issued
        and the node is copied with non-quantized inputs.

        Non-reference mode:
          * int8: a call_module node has its module swapped for the static
            quantized class; a call_function node is replaced by the quantized
            operator with `output_scale` / `output_zero_point` kwargs.
          * fp16: a warning is issued and the node is copied unchanged.
        Reference mode:
          * int8: the float node is copied and its output is passed through
            `quantize_node`.
          * fp16: the node is copied unchanged.
        '''
        if not self.all_node_args_are_tensors:
            return NotImplemented
        assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
            'call_function are handled in DefaultNode'
        if convert_custom_config_dict is None:
            convert_custom_config_dict = {}
        additional_static_quant_mapping = convert_custom_config_dict.get("static", {})

        # tuples are (activation_dtype, weight_dtype, compute_dtype)
        all_dtypes = [
            (torch.quint8, torch.qint8, None),
            (torch.float16, torch.float16, None)
        ]
        int8_dtypes = [
            (torch.quint8, torch.qint8, None)
        ]
        fp16_dtypes = [
            (torch.float16, torch.float16, None)
        ]
        supported_dtypes = {
            torch.nn.ConvTranspose1d: int8_dtypes,
            torch.nn.ConvTranspose2d: int8_dtypes,
            torch.nn.ELU: int8_dtypes,
            torch.nn.LeakyReLU: int8_dtypes,
            torch.nn.Hardswish: int8_dtypes,
            torch.nn.InstanceNorm1d: int8_dtypes,
            torch.nn.InstanceNorm2d: int8_dtypes,
            torch.nn.InstanceNorm3d: int8_dtypes,
            torch.nn.LayerNorm: all_dtypes,
            torch.nn.SiLU: fp16_dtypes,
            torch.nn.GELU: int8_dtypes,
            torch.nn.Softmax: int8_dtypes,
            torch.nn.functional.hardswish: int8_dtypes,
            torch.nn.functional.instance_norm: int8_dtypes,
            torch.nn.functional.layer_norm: all_dtypes,
            torch.nn.functional.leaky_relu: int8_dtypes,
            torch.nn.functional.silu: fp16_dtypes,
            torch.nn.functional.gelu: int8_dtypes,
            torch.nn.functional.softmax: int8_dtypes,
        }
        qconfig = quantizer.qconfig_map[node.name]
        dtypes = get_qconfig_dtypes(qconfig)
        if dtypes not in supported_dtypes[self.op]:
            warnings.warn(
                "dtype combination: {} is not "
                "supported by {} "
                "supported dtype combinations are: {}".format(dtypes, self.op, supported_dtypes[self.op]))
            return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
        # TODO: make helper functions for (torch.quint8, torch.qint8, None)
        if not is_reference:
            if dtypes in [(torch.quint8, torch.qint8, None)]:
                # take the next unused activation post process for this node
                cur_idx = quantizer.activation_post_process_indexes[node.name]
                activation_post_process = \
                    quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
                quantizer.activation_post_process_indexes[node.name] += 1
                if node.op == 'call_module':
                    module = quantizer.modules[node.target]
                    module.activation_post_process = activation_post_process
                    quantized_module_cls = get_static_quant_module_class(
                        type(module), additional_static_quant_mapping)
                    quantized_module = quantized_module_cls.from_float(module)
                    parent_name, name = _parent_name(node.target)
                    setattr(quantizer.modules[parent_name], name, quantized_module)
                    return quantizer.quantized_graph.create_node(
                        'call_module',
                        node.target,
                        load_arg(quantized=[0])(node.args),
                        load_arg(quantized=False)(node.kwargs))
                else:
                    assert node.op == "call_function"
                    # call_function
                    scale, zero_point = activation_post_process.calculate_qparams()
                    scale = float(scale)
                    zero_point = int(zero_point)

                    scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)

                    # Parenthesized so both string parts belong to the assert
                    # message; previously the second literal was a separate
                    # no-op statement and the message was cut off.
                    assert not isinstance(node.target, str), (
                        "Expecting node.target for "
                        "call_function to be a function instead of a string")
                    quantized_op = get_quantized_operator(node.target)
                    args = load_arg(quantized=[0])(node.args)
                    kwargs = {**load_arg(quantized=False)(node.kwargs), "output_scale": scale_arg,
                              "output_zero_point": zero_point_arg}
                    # drop the kwargs listed in ARGS_TO_SKIP for this quantized op
                    if quantized_op in ARGS_TO_SKIP:
                        args_to_skip = ARGS_TO_SKIP[quantized_op]
                        for arg in args_to_skip:
                            if arg in kwargs:
                                kwargs.pop(arg)
                    return quantizer.quantized_graph.create_node(
                        "call_function", quantized_op, args, kwargs)
            else:
                assert dtypes in [(torch.float16, torch.float16, None)]
                # Generally fp16 kernels don't exist for fp16 ops
                warnings.warn(
                    "Only reference patterns are currently supported for {dtype} dtype with {op} op"
                    "".format(dtype=dtypes, op=self.op))
                return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
        else:
            assert is_reference
            if dtypes in [(torch.quint8, torch.qint8, None)]:
                # NOTE(review): the results of these three loads are not used.
                # The calls are kept in their original order in case load_arg
                # has side effects on the quantizer environment -- confirm
                # before removing them.
                load_arg(quantized=[0])(node.args)
                load_arg(quantized=False)(node.args)
                load_arg(quantized=False)(node.kwargs)
                op_out = quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
                cur_idx = quantizer.activation_post_process_indexes[node.name]
                activation_post_process = \
                    quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
                quantizer.activation_post_process_indexes[node.name] += 1
                return quantize_node(
                    quantizer, op_out, activation_post_process,
                    node, is_input=False)
            else:
                assert dtypes in [(torch.float16, torch.float16, None)]
                return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
|
|
|
|
|
|
# TODO: elu is using scale/zero_point instead of output_scale, output_zero_point
|
|
@register_quant_pattern(torch.nn.functional.elu)
class ELUQuantizeHandler(QuantizeHandler):
    """ Replaces a functional elu call with its quantized operator.
    """
    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """ Emit a call_function node for the quantized operator of
        `node.target`, passing the output qparams computed by the node's
        activation post process as `output_scale` / `output_zero_point`.

        The first positional input is loaded quantized; the original kwargs
        are loaded non-quantized, and `inplace` is dropped if present.
        """
        # take the next unused activation post process for this node
        cur_idx = quantizer.activation_post_process_indexes[node.name]
        activation_post_process = \
            quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
        quantizer.activation_post_process_indexes[node.name] += 1
        scale, zero_point = activation_post_process.calculate_qparams()
        scale = float(scale)
        zero_point = int(zero_point)

        scale_arg, zero_point_arg = create_qparam_nodes(quantizer, node.name, scale, zero_point)

        quantized_op = get_quantized_operator(node.target)
        args = load_arg(quantized=[0])(node.args)
        kwargs = {**load_arg(quantized=False)(node.kwargs), 'output_scale': scale_arg, 'output_zero_point': zero_point_arg}
        # `inplace` is only present when the caller passed it by keyword, so
        # pop with a default instead of raising KeyError when it is absent
        kwargs.pop('inplace', None)
        return quantizer.quantized_graph.create_node(
            'call_function', quantized_op, args, kwargs)
|
|
|
|
@register_quant_pattern(torch.nn.Hardsigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.nn.functional.hardsigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('hardsigmoid', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('hardsigmoid_', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.nn.Sigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.sigmoid, default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('sigmoid', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern('sigmoid_', default_affine_fixed_qparams_fake_quant)
@register_quant_pattern(torch.nn.Tanh, default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern(torch.tanh, default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern('tanh', default_symmetric_fixed_qparams_fake_quant)
@register_quant_pattern('tanh_', default_symmetric_fixed_qparams_fake_quant)
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
    """ Handler for the ops registered above together with a fixed-qparams
    fake quantize (hardsigmoid, sigmoid and tanh variants).
    """
    def __init__(self, quantizer: QuantizerCls, node: Node):
        super().__init__(quantizer, node)
        self.node = node

    def should_insert_observer_for_output(
        self,
        qconfig: Any,
        model_is_training: bool,
    ) -> bool:
        """ Return whether an observer / fake quant goes on this op's output.
        """
        if model_is_training:
            # in QAT, always insert fake_quants
            return True
        else:
            # in PTQ, only insert observers when emulating fp16
            return activation_dtype(qconfig) == torch.float16

    def should_mark_output_observed_from_input_observed_status(
        self,
        observed_node_names_set: Set[str],
    ) -> bool:
        """ Return True when the first argument of the node is observed, i.e.
        it is a Node whose name is in `observed_node_names_set`, or a list
        whose elements all are.
        """
        # For these ops if input is observed, output is also observed
        def is_observed(input_arg) -> bool:
            if isinstance(input_arg, Node):
                return input_arg.name in observed_node_names_set
            elif isinstance(input_arg, list):
                return all(map(is_observed, input_arg))
            # anything else is not observed; return an explicit bool rather
            # than falling through and implicitly returning None
            return False
        return is_observed(self.node.args[0])

    def should_mark_output_quantized_from_input_quantized_status(
        self,
    ) -> bool:
        return True

    # some qhandlers override the activations constructor
    def get_activation_ctr(self, qconfig, pattern) -> Optional[Callable]:
        """ Return the activation constructor for this pattern: the qconfig's
        own activation for fp16, otherwise the entry for `pattern` in the
        default output activation post process map (None if absent).
        """
        if activation_dtype(qconfig) == torch.float16:
            return qconfig.activation
        else:
            return get_default_output_activation_post_process_map().get(
                pattern, None)

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """ Copy the node into the quantized graph, loading its inputs
        non-quantized for the fp16 dtype combination and with
        `quantized=None` otherwise.
        """
        qconfig = quantizer.qconfig_map[node.name]
        dtypes = get_qconfig_dtypes(qconfig)
        if dtypes == (torch.float16, torch.float16, None):
            return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))
        else:
            return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
|
|
|
|
|
|
# these ops have quantized equivalents that do not need any extra information
|
|
@register_quant_pattern(torch.nn.AdaptiveAvgPool1d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool3d)
@register_quant_pattern(torch.nn.AvgPool1d)
@register_quant_pattern(torch.nn.AvgPool2d)
@register_quant_pattern(torch.nn.AvgPool3d)
@register_quant_pattern(torch.nn.Dropout)
@register_quant_pattern(torch.nn.Hardtanh)
@register_quant_pattern(torch.nn.Identity)
@register_quant_pattern(torch.nn.MaxPool1d)
@register_quant_pattern(torch.nn.MaxPool2d)
@register_quant_pattern(torch.nn.MaxPool3d)
@register_quant_pattern(torch.nn.ReLU)
@register_quant_pattern(torch.nn.ReLU6)
@register_quant_pattern(torch.adaptive_avg_pool1d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool3d)
@register_quant_pattern(torch.nn.functional.dropout)
@register_quant_pattern(torch.nn.functional.hardtanh)
@register_quant_pattern(torch.nn.functional.hardtanh_)
@register_quant_pattern(torch.nn.functional.interpolate)
@register_quant_pattern(torch.nn.functional.max_pool1d)
@register_quant_pattern(torch.nn.functional.max_pool2d)
@register_quant_pattern(torch.nn.functional.max_pool3d)
@register_quant_pattern(torch.nn.functional.relu)
@register_quant_pattern(torch.nn.functional.relu6)
@register_quant_pattern(torch.avg_pool1d)
@register_quant_pattern(torch._C._nn.avg_pool2d)
@register_quant_pattern(torch._C._nn.avg_pool3d)
@register_quant_pattern(torch.chunk)
@register_quant_pattern(torch.clamp)
@register_quant_pattern(torch.flatten)
@register_quant_pattern(torch.transpose)
@register_quant_pattern(torch.max)
@register_quant_pattern(torch.mean)
@register_quant_pattern(torch.min)
@register_quant_pattern(torch.repeat_interleave)
@register_quant_pattern(torch.sort)
@register_quant_pattern(torch.squeeze)
@register_quant_pattern(torch.stack)
@register_quant_pattern(torch.unsqueeze)
@register_quant_pattern(operator.floordiv)
@register_quant_pattern(operator.getitem)
@register_quant_pattern('chunk')
@register_quant_pattern('clamp')
@register_quant_pattern('contiguous')
@register_quant_pattern('detach')
@register_quant_pattern('detach_')
@register_quant_pattern('mean')
@register_quant_pattern('numel')
@register_quant_pattern('permute')
@register_quant_pattern('relu')
@register_quant_pattern('relu_')
@register_quant_pattern('repeat')
@register_quant_pattern('repeat_interleave')
@register_quant_pattern('reshape')
@register_quant_pattern('resize_')
@register_quant_pattern('shape')
@register_quant_pattern('size')
@register_quant_pattern('squeeze')
@register_quant_pattern('squeeze_')
@register_quant_pattern('transpose')
@register_quant_pattern('unsqueeze')
@register_quant_pattern('unsqueeze_')
@register_quant_pattern('view')
class CopyNodeQuantizeHandler(QuantizeHandler):
    """ Handler for ops that are copied into the quantized graph as-is.
    """
    def should_mark_output_quantized_from_input_quantized_status(
        self,
    ) -> bool:
        # always True for this handler
        return True

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """ Copy `node` into the quantized graph, loading its arguments with
        `quantized=None`.
        """
        out_graph = quantizer.quantized_graph
        arg_loader = load_arg(quantized=None)
        copied_node = out_graph.node_copy(node, arg_loader)
        return copied_node
|
|
|
|
# Default quantization handler, used for quantization of input and output
|
|
# of quantizable objects (e.g. modules and functionals)
|
|
class DefaultQuantizeHandler(QuantizeHandler):
    """ Quantizes the output of a node using its activation post process.
    """
    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """ Return the result of `quantize_node` applied to `node` with the
        next unused activation post process recorded for it.

        Asserts that all node args are tensors. `load_arg`, `is_reference`
        and `convert_custom_config_dict` are not used by this handler.
        """
        assert self.all_node_args_are_tensors
        # take the next unused activation post process and advance the index
        cur_idx = quantizer.activation_post_process_indexes[node.name]
        activation_post_process = \
            quantizer.modules[quantizer.activation_post_process_map[node.name][cur_idx]]
        quantizer.activation_post_process_indexes[node.name] += 1
        return quantize_node(
            quantizer,
            node, activation_post_process, node, is_input=False)
|
|
|
|
class CustomModuleQuantizeHandler(QuantizeHandler):
    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """ Convert a float custom module to quantized custom module

        The observed module is swapped for the class selected by
        `get_swapped_custom_module_class`, built via `from_observed`.
        Requires `convert_custom_config_dict` to contain the key
        "observed_to_quantized_custom_module_class".
        """
        assert node.op == 'call_module'
        assert convert_custom_config_dict is not None
        class_mapping = convert_custom_config_dict.get(
            "observed_to_quantized_custom_module_class", None)
        assert class_mapping is not None

        qconfig = quantizer.qconfig_map[node.name]
        observed_module = quantizer.modules[node.target]

        if activation_is_statically_quantized(qconfig):
            # hand the next unused activation post process to the observed module
            assert node.name in quantizer.activation_post_process_map
            observer_idx = quantizer.activation_post_process_indexes[node.name]
            observer_name = quantizer.activation_post_process_map[node.name][observer_idx]
            observed_module.activation_post_process = quantizer.modules[observer_name]
            quantizer.activation_post_process_indexes[node.name] += 1

        quantized_cls = get_swapped_custom_module_class(
            observed_module, class_mapping, qconfig)
        quantized_module = quantized_cls.from_observed(observed_module)
        parent_name, attr_name = _parent_name(node.target)
        setattr(quantizer.modules[parent_name], attr_name, quantized_module)

        # The quantized input is hardcoded to None (take whatever is in the
        # environment). This can be extended if there is a need, e.g. by
        # getting the indexes of quantized inputs from some module attribute
        # like module._QUANTIZED_INPUT_INDEXES
        return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
|
|
|
|
class StandaloneModuleQuantizeHandler(QuantizeHandler):
    """ Converts an observed standalone module to quantized standalone module
    by calling convert_fx on the observed standalone module.
    """
    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
                convert_custom_config_dict: Dict[str, Any] = None) -> Node:
        """ Convert the observed standalone module behind `node`, install the
        result on its parent module and in `quantizer.modules`, then copy the
        node with the module's recorded quantized input indexes.
        """
        assert node.op == 'call_module'
        # NOTE(review): the looked-up qconfig is not used below; the lookup is
        # kept so a node missing from qconfig_map still raises here.
        qconfig = quantizer.qconfig_map[node.name]
        convert_standalone = torch.quantization.quantize_fx._convert_standalone_module_fx  # type: ignore[attr-defined]

        observed = quantizer.modules[node.target]
        quantized_input_idxs = observed._standalone_module_input_quantized_idxs.tolist()
        converted = convert_standalone(observed, is_reference=is_reference)

        # update the parent module attribute and the modules dict
        parent_name, attr_name = _parent_name(node.target)
        setattr(quantizer.modules[parent_name], attr_name, converted)
        quantizer.modules[node.target] = converted

        arg_loader = load_arg(quantized=quantized_input_idxs)
        return quantizer.quantized_graph.node_copy(node, arg_loader)
|