Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49239

Context: the existing implementation of `input_quantized_idxs` is convert-only. Therefore, observers are inserted between the input and the first quantized node. This is a problem during QAT, because the initial input is a fake_quant, and it starts with scale=1 and zp=0. This does not match the quantization parameters of the graph input, which can lead to incorrect numerics.

Fix: do not insert an observer for a quantized input.

Test Plan:
```
python test/test_quantization.py TestQuantizeFx
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25499486

fbshipit-source-id: 303b49cc9d95a9fd06fef3b0859c08be34e19d8a
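For illustration, here is a minimal sketch (not part of this commit) of how a caller might mark the first graph input as already quantized via `input_quantized_idxs`, so that prepare skips inserting an observer for it; the toy module `M` and the fbgemm qconfig choice are assumptions for this example:

```
# Hedged sketch: declaring graph input 0 as already quantized so that
# prepare does not insert an observer/fake_quant for it (this PR's fix).
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)

qconfig_dict = {"": get_default_qconfig("fbgemm")}
# input 0 is declared quantized; prepare skips its observer
prepare_custom_config_dict = {"input_quantized_idxs": [0]}

model = prepare_fx(M().eval(), qconfig_dict, prepare_custom_config_dict)
# ... calibration would run here ...
model = convert_fx(model)
```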
995 lines
43 KiB
Python
import torch
from torch.fx import (  # type: ignore
    GraphModule,
    Proxy,
    map_arg
)

from torch.fx.graph import (
    Graph,
    Node,
)

from torch.quantization import (
    propagate_qconfig_,
    convert,
)

from ..quantization_mappings import (
    get_default_qat_module_mappings,
)

from ..quantize import (
    _remove_qconfig,
    is_activation_post_process
)

from ..utils import (
    get_combined_dict,
    get_swapped_custom_module_class,
    activation_is_statically_quantized,
)

from .pattern_utils import (
    is_match,
    get_default_quant_patterns,
    get_default_output_activation_post_process_map,
    input_output_observed,
    Pattern,
)

from .observed_module import (
    mark_observed_module,
    is_observed_module,
    mark_observed_standalone_module,
    is_observed_standalone_module,
)

from .quantization_patterns import *

from .utils import (
    _parent_name,
    quantize_node,
    get_custom_module_class_keys,
    get_new_attr_name_with_prefix,
    collect_producer_nodes,
    graph_module_from_producer_nodes,
    assert_and_get_unique_device,
)

from .qconfig_utils import *

import warnings

from typing import Optional, Dict, Any, List, Union, Tuple, Set, Callable

# Define helper types

QConfigAny = Union[torch.quantization.QConfig,
                   torch.quantization.QConfigDynamic, None]
MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                    QConfigAny]
# ------------------------
# Helper Functions
# ------------------------

def insert_observer(
        node: Node, observer: torch.quantization.ObserverBase,
        model: torch.nn.Module,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[Any, Any], observed_graph: Graph, load_arg: Callable,
        observed_node_names_set: Set[str]):
    """Insert an observer for `node` by modifying `observed_graph` and
    attach the observer module to `model`

    Args:
        node: Node
        observer: observer/fake_quantize module instance
    """
    # respect device affinity when adding observers
    model_device = assert_and_get_unique_device(model)
    if model_device:
        observer.to(model_device)
    # add observer module as attribute
    prefix = node.name + '_activation_post_process_'
    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
    observer_name = get_new_observer_name(model)
    setattr(model, observer_name, observer)
    # put the observer instance in the activation_post_process map
    assert activation_post_process_map is not None
    activation_post_process_map[node.name] = observer
    # insert observer call
    env[node.name] = observed_graph.create_node(
        'call_module', observer_name, (load_arg(node),), {})
    observed_node_names_set.add(node.name)
def insert_observer_for_special_module(
        quantize_handler: QuantizeHandler, modules: Dict[str, torch.nn.Module],
        prepare_custom_config_dict: Any, qconfig: Any, node: Node):
    """ Insert observers for custom modules and standalone modules
    Returns: standalone_module_input_idxs: the indices of the inputs that
    need to be observed by the parent module
    """
    assert modules is not None
    if isinstance(quantize_handler, CustomModuleQuantizeHandler):
        custom_module = modules[node.target]  # type: ignore
        custom_module_class_mapping = prepare_custom_config_dict.get(
            "float_to_observed_custom_module_class", {})
        observed_custom_module_class = \
            get_swapped_custom_module_class(
                custom_module, custom_module_class_mapping, qconfig)
        observed_custom_module = \
            observed_custom_module_class.from_float(custom_module)
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, observed_custom_module)
    elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
        # observe standalone module
        standalone_module = modules[node.target]  # type: ignore
        prepare = \
            torch.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore
        observed_standalone_module = \
            prepare(standalone_module, {"": qconfig})
        observed_standalone_module.qconfig = qconfig
        observed_standalone_module = mark_observed_standalone_module(
            observed_standalone_module)
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name,
                observed_standalone_module)
        modules[node.target] = observed_standalone_module  # type: ignore
def insert_observer_for_output_of_the_node(
        node: Node,
        quantize_handler: QuantizeHandler,
        qconfig: Any,
        modules: Dict[str, torch.nn.Module],
        model: torch.nn.Module,
        pattern: Any,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[Any, Any],
        observed_graph: Graph,
        load_arg: Callable,
        observed_node_names_set: Set[str],
        matched_nodes: Optional[List[Node]]):
    """ Insert an observer/fake_quantize module for the output of the
    observed module if needed
    """
    # don't need to insert an observer for the output if the activation does
    # not need to be statically quantized
    assert modules is not None
    if activation_is_statically_quantized(qconfig):
        if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) \
                and model.training:
            # we only insert fake quantize module in qat
            assert pattern is not None
            activation_post_process_ctr = \
                get_default_output_activation_post_process_map().get(
                    pattern, None)
            assert activation_post_process_ctr is not None, \
                "activation_post_process constructor not provided " + \
                "for pattern:" + str(pattern)
            insert_observer(
                node, activation_post_process_ctr(),
                model, activation_post_process_map, env, observed_graph,
                load_arg, observed_node_names_set)
        elif (isinstance(quantize_handler,
                         FixedQParamsOpQuantizeHandler) and
              not model.training) or \
                isinstance(quantize_handler, CopyNode):
            # inserting observers for output of observed module, or
            # mark the output as observed
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
                'CopyNode of type ' + node.op + ' is not handled'

            def is_observed(input_arg):
                if isinstance(input_arg, Node):
                    return input_arg.name in observed_node_names_set
                elif isinstance(input_arg, list):
                    return all(map(is_observed, input_arg))
            # propagate observed property from input
            if is_observed(node.args[0]):
                observed_node_names_set.add(node.name)
        elif ((isinstance(quantize_handler, Add) or
               isinstance(quantize_handler, Mul)) and
              quantize_handler.num_node_args == 1):
            assert matched_nodes is not None
            input_node = matched_nodes[-1]  # first node in the sequence

            def input_is_observed(arg):
                return (isinstance(arg, Node) and
                        arg.name in observed_node_names_set)
            # This checks whether one of the arguments of add/mul
            # is an observed node
            # If both of the inputs are numbers,
            # we will not consider the output to be observed
            if (input_is_observed(input_node.args[0]) or
                    input_is_observed(input_node.args[1])):
                observed_node_names_set.add(node.name)
        elif isinstance(quantize_handler,
                        StandaloneModuleQuantizeHandler):
            # output is observed in the standalone module
            return
        elif (quantize_handler.all_node_args and
                input_output_observed(quantize_handler)):
            # observer for outputs
            new_observer = qconfig.activation()
            insert_observer(
                node, new_observer, model,
                activation_post_process_map, env, observed_graph,
                load_arg, observed_node_names_set)
def insert_observer_for_input_arg_of_observed_node(
        node: Node, observed_node_names_set: Set[str], quants: Dict[str, Any],
        model: torch.nn.Module,
        activation_post_process_map: Dict[str, torch.quantization.ObserverBase],
        env: Dict[str, str], observed_graph: Graph,
        load_arg: Callable):
    if node.name not in observed_node_names_set and node.name in quants:
        _, activation_post_process_ctr = quants[node.name]
        if activation_post_process_ctr is not None:
            insert_observer(
                node, activation_post_process_ctr(),
                model, activation_post_process_map,
                env, observed_graph, load_arg, observed_node_names_set)
# A dictionary for querying the weight index for a given op
WEIGHT_INDEX_DICT = {
    torch.nn.functional.conv2d : [1],
    torch.nn.functional.linear : [1],
}
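# For example, in torch.nn.functional.linear(input, weight, bias=None) the
# weight is positional argument 1, which is why its index is [1] above.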
# weight prepacking ops
WEIGHT_PREPACK_OPS = {
    torch._ops.ops.quantized.linear_prepack,
    torch._ops.ops.quantized.linear_prepack_fp16,
    torch._ops.ops.quantized.conv2d_prepack,
}
class Quantizer:
    def __init__(self):
        # mapping from matched node to activation_post_process
        # must be filled before convert
        self.activation_post_process_map: Optional[
            Dict[str, torch.quantization.observer.ObserverBase]] = None
        # mapping from node name to qconfig that should be used for that node
        # filled out for a model during _generate_qconfig_map
        self.qconfig_map: Optional[Dict[str, QConfigAny]] = None
        # mapping from fully qualified module name to module instance
        # for example,
        # {
        #   '': Model(...),
        #   'linear': Linear(...),
        #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
        # }
        self.modules: Optional[Dict[str, torch.nn.Module]] = None
        # mapping from a tuple of nodes in reverse order to uninitialized
        # QuantizeHandler subclass. For example,
        # {
        #   # match a single node
        #   (<class 'torch.nn.modules.conv.Conv3d'>:
        #     <class 'torch.quantization.fx.quantize.ConvRelu'>),
        #   # match multiple nodes in reverse order
        #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
        #     <class 'torch.quantization.fx.quantize.Add'>),
        # }
        self.patterns: Optional[Dict[Pattern, QuantizeHandler]] = None
        self.prepare_custom_config_dict: Dict[str, Any] = {}
    def _qat_swap_modules(
            self, root: torch.nn.Module,
            additional_qat_module_mapping: Dict[Callable, Callable]) -> None:
        all_mappings = get_combined_dict(
            get_default_qat_module_mappings(), additional_qat_module_mapping)
        convert(root, mapping=all_mappings, inplace=True, remove_qconfig=False)
    def _generate_qconfig_map(
            self,
            root: torch.nn.Module,
            input_graph: Graph,
            qconfig_dict: Any) -> None:
        global_qconfig = qconfig_dict.get('', None)

        self.qconfig_map = dict()
        for node in input_graph.nodes:
            if node.op == 'get_attr':
                module_name, _ = _parent_name(node.target)
                self.qconfig_map[node.name] = get_qconfig(
                    self.modules, qconfig_dict, module_name, global_qconfig)
            elif node.op == 'call_function':
                # precedence: [TODO] module_name_qconfig (need scope support
                # from fx)
                # > function_qconfig > global_qconfig
                function_qconfig = get_function_qconfig(
                    qconfig_dict, node.target, global_qconfig)
                self.qconfig_map[node.name] = function_qconfig
            elif node.op == 'call_method':
                self_obj = node.args[0]
                # qconfig for call_method should be the same as the `self`
                # object for the call
                if self_obj.name in self.qconfig_map:
                    qconfig = self.qconfig_map[self_obj.name]
                else:
                    # need scope info for each node to support this
                    warnings.warn(
                        "Scope info is not yet supported, taking default " +
                        "qconfig for value {}".format(node.name))
                    qconfig = get_qconfig(
                        self.modules, qconfig_dict, '', global_qconfig)
                self.qconfig_map[node.name] = qconfig
            elif node.op == 'call_module':
                module_qconfig = get_qconfig(
                    self.modules, qconfig_dict, node.target, global_qconfig)
                # regex is not supported in eager mode propagate_qconfig_, so
                # we need to set the qconfig explicitly here in case regex
                # is used
                assert self.modules is not None
                self.modules[node.target].qconfig = module_qconfig
                self.qconfig_map[node.name] = module_qconfig
    def _prepare(self, model: GraphModule, qconfig_dict: Any,
                 prepare_custom_config_dict: Optional[Dict[str, Any]],
                 is_standalone_module: bool) -> GraphModule:
        """ standalone_module means it is a submodule that is not inlined in
        the parent module, and will be quantized separately as one unit.

        When we are preparing a standalone module:
        both input and output are observed in the prepared standalone module

        Returns:
            model(GraphModule): prepared standalone module
        """
        if prepare_custom_config_dict is None:
            prepare_custom_config_dict = {}
        self.prepare_custom_config_dict = prepare_custom_config_dict

        additional_quant_patterns = \
            prepare_custom_config_dict.get("additional_quant_pattern", {})
        self.patterns = get_combined_dict(
            get_default_quant_patterns(), additional_quant_patterns)

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            additional_qat_module_mapping = prepare_custom_config_dict.get(
                "additional_qat_module_mapping", {})
            self._qat_swap_modules(model, additional_qat_module_mapping)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = prepare_custom_config_dict.get(
            "standalone_module_name", None)
        standalone_module_classes = prepare_custom_config_dict.get(
            "standalone_module_class", None)
        custom_module_classes = get_custom_module_class_keys(
            prepare_custom_config_dict, "float_to_observed_custom_module_class")
        assert self.patterns is not None
        matches = self._find_matches(
            model.graph, self.modules, self.patterns, standalone_module_names,
            standalone_module_classes, custom_module_classes)

        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a DefaultQuantizeHandler object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        env: Dict[Any, Any] = {}
        observed_graph = Graph()
        observed_node_names_set: Set[str] = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indices of the inputs that need to be observed
        standalone_module_observed_input_idxs: List[int] = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')

        placeholder_node_seen_cnt = 0
        input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "input_quantized_idxs", [])
        result_node : Optional[Node] = None
        for node in model.graph.nodes:
            if node.op == 'output':
                observed_graph.output(load_arg(node.args[0]))
                result_node = node
                continue
            if node.name in observed_node_names_set:
                continue

            root_node, matched_nodes, pattern, obj, qconfig = matches.get(
                node.name, (None, None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                # index for input of custom module that needs to be observed in
                # parent
                if qconfig is not None:
                    assert obj is not None
                    insert_observer_for_special_module(
                        obj, self.modules, prepare_custom_config_dict, qconfig,
                        node)
                    insert_observer_for_output_of_the_node(
                        node, obj, qconfig, self.modules, model, pattern,
                        self.activation_post_process_map, env,
                        observed_graph, load_arg, observed_node_names_set,
                        matched_nodes)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)
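            # Context (from the commit summary): during QAT a freshly
            # initialized fake_quant at the graph input starts with scale=1
            # and zero_point=0, which may not match the quantization params
            # of an already-quantized input, so observation is skipped for
            # inputs listed in input_quantized_idxs.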
            if node.op == 'placeholder':
                # skip adding observers at the graph input if the input is
                # overridden to be quantized
                cur_placeholder_node_idx = placeholder_node_seen_cnt
                placeholder_node_seen_cnt += 1
                if cur_placeholder_node_idx in input_quantized_idxs:
                    continue

            insert_observer_for_input_arg_of_observed_node(
                node, observed_node_names_set, quants,
                model, self.activation_post_process_map, env,
                observed_graph, load_arg)
        model = GraphModule(model, observed_graph)
        self.save_state(model)
        model = mark_observed_module(model)
        return model
    def save_state(self, observed: GraphModule) -> None:
        observed._activation_post_process_map = \
            self.activation_post_process_map  # type: ignore
        observed._patterns = self.patterns  # type: ignore
        observed._qconfig_map = self.qconfig_map  # type: ignore
        observed._prepare_custom_config_dict = \
            self.prepare_custom_config_dict  # type: ignore
    def restore_state(self, observed: GraphModule) -> None:
        assert is_observed_module(observed), \
            'incoming model must be produced by prepare_fx'
        self.activation_post_process_map = \
            observed._activation_post_process_map  # type: ignore
        self.patterns = observed._patterns  # type: ignore
        self.qconfig_map = observed._qconfig_map  # type: ignore
        self.prepare_custom_config_dict = \
            observed._prepare_custom_config_dict  # type: ignore
    def prepare(self, model: GraphModule, qconfig_dict: Any,
                prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
                is_standalone_module: bool = False) -> GraphModule:
        return self._prepare(
            model, qconfig_dict, prepare_custom_config_dict,
            is_standalone_module)
    def _run_weight_observers(self, observed: GraphModule) -> None:
        r''' Extract the subgraph that produces the weight for a dynamic quant
        or weight only quant node and run the subgraph to observe the weight.
        Note that the observers of dynamic quant or weight only quant ops are
        run during the convert step.
        '''
        for node in observed.graph.nodes:
            if node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
                for i, node_arg in enumerate(node.args):
                    if i in WEIGHT_INDEX_DICT[node.target]:
                        # node_arg is weight
                        weight_observer_nodes = collect_producer_nodes(node_arg)
                        if weight_observer_nodes is not None:
                            weight_observer_module = \
                                graph_module_from_producer_nodes(
                                    observed, weight_observer_nodes)
                            # run the weight observer
                            weight_observer_module()
        return
    def _convert(self, model: GraphModule, debug: bool = False,
                 convert_custom_config_dict: Optional[Dict[str, Any]] = None,
                 is_standalone_module: bool = False) -> GraphModule:
        """ standalone_module means it is a submodule that is not inlined in
        the parent module, and will be quantized separately as one unit.

        Returns a quantized standalone module which accepts float input
        and produces float output.
        """
        if convert_custom_config_dict is None:
            convert_custom_config_dict = {}
        self.restore_state(model)
        # always run weight observers in the top level forward method
        # for dynamic quant ops or weight only quant ops
        self._run_weight_observers(model)

        # move to cpu since we only have quantized cpu kernels
        model.eval().cpu()
        self.modules = dict(model.named_modules())
        custom_module_classes = get_custom_module_class_keys(
            convert_custom_config_dict,
            "observed_to_quantized_custom_module_class")
        assert self.patterns is not None
        matches = self._find_matches(
            model.graph, self.modules, self.patterns,
            custom_module_classes=custom_module_classes)

        quants = self._find_quants(model.graph, matches)

        self.quantized_graph = Graph()
        env: Dict[Any, Any] = {}
        quant_env: Dict[Any, Any] = {}

        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)
        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find ' + \
                    'node:' + n.name + \
                    ' in quantized or non quantized environment, env: ' + \
                    str(env) + ' quant_env:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + \
                    n.name + ' in float environment:' + str(env)
                assert n.name in quants, \
                    'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]
        def load_arg(quantized):
            """
            Input: quantized, which can be None, a list, a boolean or a tuple
              - if quantized is a list or tuple, then arg should be a list and
                the args with corresponding indices will be quantized
              - if quantized is a boolean, then all args will be
                quantized/not quantized
              - if quantized is None, then we'll load the node as long as it
                exists

            Output: fn which takes arg_or_args, and loads them from the
            corresponding environment depending on the value of quantized.
            """
            assert quantized is None or \
                isinstance(quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg_or_args):
                if quantized is None:
                    return map_arg(arg_or_args, load_x)
                if isinstance(quantized, bool):
                    return map_arg(
                        arg_or_args,
                        load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg_or_args, (tuple, list)), arg_or_args
                    loaded_args = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg_or_args):
                        if i in quantized:
                            loaded_args.append(map_arg(a, load_quantized))
                        else:
                            loaded_args.append(map_arg(a, load_non_quantized))
                    return type(arg_or_args)(loaded_args)
            return load_arg_impl
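        # For example, load_arg(quantized=[1])(node.args) loads positional
        # arg 1 from the quantized environment and the remaining args from
        # the float environment.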
        def is_quantized(node):
            if isinstance(node, Node):
                assert node.name in env or node.name in quant_env, \
                    'Expecting node to be in the environment'
                # there might be nodes appearing in both environments, but
                # quant_env will take precedence
                if node.name in quant_env:
                    return True
                elif node.name in env:
                    return False
            elif isinstance(node, list):
                # materialize the results; a bare map iterator would be
                # exhausted by all() before any() could inspect it
                quantized = list(map(is_quantized, node))
                if all(quantized):
                    return True
                elif not any(quantized):
                    return False
                else:
                    raise Exception(
                        "partially quantized inputs in list not handled yet")
        def is_output_quantized(node) -> bool:
            """ Check if output node is quantized or not """
            assert self.modules is not None
            # by default the output is expected to be quantized
            quantized = True

            # Need to get correct quantized/non-quantized state for the output
            # of CopyNode
            if type(obj) in [
                    CopyNode,
                    FixedQParamsOpQuantizeHandler
            ]:
                assert node.op in [
                    'call_module',
                    'call_function',
                    'call_method'], \
                    'CopyNode of type ' + node.op + ' is not handled'
                quantized = is_quantized(node.args[0])

            if not activation_is_statically_quantized(qconfig) or \
                    not input_output_observed(obj):
                quantized = False

            return quantized
        def insert_quantize_node(node):
            """ Given an activation_post_process module call node, insert a
            quantize node"""
            assert self.modules is not None
            observer_module = self.modules[node.target]
            prev_node = node.args[0]
            if observer_module.dtype == torch.float16:
                # activations are not quantized for
                # fp16 dynamic quantization
                # copy the activation_post_process node here
                # since we may need it when we insert the prepack
                # op for the weight of linear; this will be removed
                # later in a separate pass
                env[node.name] = self.quantized_graph.node_copy(
                    node, load_non_quantized)
            elif prev_node.name in quant_env:
                # if the previous node is already quantized, we'll just remove
                # the activation_post_process
                quant_env[node.name] = quant_env[prev_node.name]
            else:
                # replace activation post process with quantization ops
                root_module = self.modules[""]
                quant_env[node.name] = quantize_node(
                    root_module, self.quantized_graph,
                    load_non_quantized(node.args[0]), observer_module)
        # additional state to override inputs to be quantized, if specified
        # by the user
        placeholder_node_seen_cnt = 0
        output_node_seen_cnt = 0
        input_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "input_quantized_idxs", [])
        output_quantized_idxs: List[int] = self.prepare_custom_config_dict.get(
            "output_quantized_idxs", [])
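        # e.g. {"output_quantized_idxs": [0]} requests that the first model
        # output stay quantized instead of being dequantized back to float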
        for node in model.graph.nodes:
            if node.op == 'output':
                cur_output_node_idx = output_node_seen_cnt
                output_node_seen_cnt += 1
                if cur_output_node_idx in output_quantized_idxs:
                    # Results are kept quantized if the user specified the
                    # output_quantized_idxs override.
                    graph_output = map_arg(node.args[0], load_x)
                else:
                    graph_output = map_arg(node.args[0], load_non_quantized)
                self.quantized_graph.output(graph_output)
                continue
            root_node, matched, matched_pattern, obj, qconfig = \
                matches.get(node.name, (None, None, None, None, None))
            if root_node is node:
                if qconfig is None:
                    result = self.quantized_graph.node_copy(
                        node, load_non_quantized)
                    quantized = False
                else:
                    assert obj is not None
                    is_standalone_module_node = (
                        node.op == 'call_module' and
                        is_observed_standalone_module(
                            self.modules[node.target])  # type: ignore
                    )
                    result = obj.convert(
                        self, node, load_arg, debug=debug,
                        convert_custom_config_dict=convert_custom_config_dict)
                    if is_standalone_module_node:
                        quantized = False
                    else:
                        quantized = is_output_quantized(node)

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue
            # handle activation post process calls
            if node.op == 'call_module' and \
                    is_activation_post_process(self.modules[node.target]):
                insert_quantize_node(node)
            elif node.op == 'placeholder':
                cur_placeholder_node_idx = placeholder_node_seen_cnt
                placeholder_node_seen_cnt += 1
                if cur_placeholder_node_idx in input_quantized_idxs:
                    quant_env[node.name] = \
                        self.quantized_graph.node_copy(node, load_non_quantized)
                else:
                    env[node.name] = \
                        self.quantized_graph.node_copy(node, load_non_quantized)
            else:
                # copy quantized or non-quantized node
                env[node.name] = \
                    self.quantized_graph.node_copy(node, load_non_quantized)
        # remove activation post process
        act_post_process_removed_graph = Graph()
        env = {}

        def load_arg(a):  # type: ignore
            return map_arg(a, lambda node: env[node.name])
        for node in self.quantized_graph.nodes:
            if node.op == 'output':
                act_post_process_removed_graph.output(
                    map_arg(node.args[0], load_arg))
                continue
            if node.op == 'call_module' and \
                    is_activation_post_process(self.modules[node.target]):
                # remove activation post process node
                env[node.name] = env[node.args[0].name]
            else:
                env[node.name] = act_post_process_removed_graph.node_copy(
                    node, load_arg)

        # removes qconfig and activation_post_process modules
        _remove_qconfig(model)
        model = GraphModule(model, act_post_process_removed_graph)
        return model
    # Trace back from the weight node until we hit getattr, reconstruct the
    # graph module with the traced nodes and run the graph module to pack the
    # weight, then replace the original chain of ops with the packed weight.
    def _fold_weight(self, quantized: GraphModule) -> GraphModule:
        packed_weights = dict()
        # map from folded node name to the prepacked weight name
        folded_nodes = dict()
        # get packed weights
        for node in quantized.graph.nodes:
            if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
                nodes_to_fold = collect_producer_nodes(node)
                if nodes_to_fold is not None:
                    for node_to_fold in nodes_to_fold:
                        folded_nodes[node_to_fold.name] = node

                    prepacking_module = graph_module_from_producer_nodes(
                        quantized, nodes_to_fold)
                    packed_weight = prepacking_module()
                    packed_weights[node.name] = packed_weight

        # remove folded nodes and replace the prepacking node with getattr
        folded_graph = Graph()
        env: Dict[Any, Any] = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])
        get_new_packed_weight_name = \
            get_new_attr_name_with_prefix('_fx_pass_packed_weight_')
        quantized_root = quantized
        quantized_graph = quantized.graph
        for node in quantized_graph.nodes:
            prepack_node = folded_nodes.get(node.name, None)
            if prepack_node is node:
                packed_weight = packed_weights[node.name]
                # add a prepacked attribute to root
                packed_weight_name = get_new_packed_weight_name(quantized_root)
                setattr(quantized_root, packed_weight_name, packed_weight)
                # replace prepack node with a getattr node
                env[node.name] = folded_graph.create_node(
                    'get_attr', packed_weight_name, (), {})
            elif prepack_node is not None:
                # remove the folded node
                continue
            else:
                # copy other nodes
                env[node.name] = folded_graph.node_copy(node, load_arg)
        quantized = GraphModule(quantized_root, folded_graph)
        return quantized
    def convert(self, model: GraphModule, debug: bool = False,
                convert_custom_config_dict: Optional[Dict[str, Any]] = None,
                is_standalone_module: bool = False) -> GraphModule:
        quantized = self._convert(
            model, debug, convert_custom_config_dict, is_standalone_module)
        if not debug:
            quantized = self._fold_weight(quantized)
        return quantized
    def _find_matches(
            self, graph: Graph, modules: Dict[str, torch.nn.Module],
            patterns: Dict[Pattern, QuantizeHandler],
            standalone_module_names: Optional[List[str]] = None,
            standalone_module_classes: Optional[List[Callable]] = None,
            custom_module_classes: Optional[List[Any]] = None
    ) -> Dict[str, MatchResult]:
        """
        Matches the nodes in the input graph to quantization patterns, and
        outputs the information needed to quantize them in future steps.

        Inputs:
          - graph: an fx.Graph object
          - modules: a mapping of fully qualified module name to instance,
            for example, {'foo': ModuleFoo, ...}
          - patterns: a mapping from a tuple of nodes in reverse order to
            uninitialized QuantizeHandler subclass.

        Outputs a map of
          node_name ->
            (node, matched_values, matched_pattern, QuantizeHandler instance,
             qconfig)

        For example, {
          'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
                     <CopyNode instance>, QConfig(...)),
          ...
        }
        """
        if custom_module_classes is None:
            custom_module_classes = []

        if standalone_module_classes is None:
            standalone_module_classes = []

        if standalone_module_names is None:
            standalone_module_names = []

        match_map: Dict[str, MatchResult] = {}
        all_matched = set()

        def record_match(pattern, node, matched):
            if isinstance(pattern, tuple):
                s, *args = pattern
                record_match(s, node, matched)
                if pattern[0] is not getattr:
                    for subpattern, arg in zip(args, node.args):
                        record_match(subpattern, arg, matched)
            else:
                matched.append(node)

        assert self.qconfig_map is not None
        for node in reversed(graph.nodes):
            if node.name not in match_map and node.name not in all_matched:
                for pattern, value in patterns.items():
                    if is_match(modules, node, pattern):
                        matched: List[Any] = []
                        record_match(pattern, node, matched)
                        for n in matched:
                            match_map[n.name] = (
                                node, matched, pattern, value(self, node),  # type: ignore
                                self.qconfig_map[n.name])
                            all_matched.add(n.name)
                        # break after finding the first match
                        break

        # add custom module instances to the match result
        assert self.modules is not None
        for node in graph.nodes:
            if node.op == 'call_module' and \
               type(self.modules[node.target]) in custom_module_classes:
                custom_module_qconfig = self.qconfig_map[node.name]
                match_map[node.name] = (
                    node, [node], None, CustomModuleQuantizeHandler(self, node),
                    custom_module_qconfig)

        def is_standalone_module(node_target):
            assert self.modules is not None
            return (
                node_target in standalone_module_names or  # type: ignore
                type(self.modules[node_target]) in standalone_module_classes  # type: ignore
            )

        # add standalone modules to the match
        for node in graph.nodes:
            if node.op == 'call_module' and \
               (is_standalone_module(node.target) or
                    is_observed_standalone_module(self.modules[node.target])):
                # add node to matched nodes
                custom_module_qconfig = self.qconfig_map[node.name]
                match_map[node.name] = (
                    node, [node], None,
                    StandaloneModuleQuantizeHandler(self, node),
                    custom_module_qconfig)

        return match_map
    def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
                     ) -> Dict[str, Any]:
        """
        Takes the nodes in the input graph and pending matches, and finds and
        returns the input and output nodes which need to be quantized.

        Inputs:
          - graph: an fx.Graph object
          - matches: output of self._find_matches function

        Outputs a map of
          node_name -> (QuantizeHandler instance (always DefaultQuantizeHandler),
                        activation_post_process (observer/fake_quantize module)
                        constructor)
        """
        quants: Dict[str, Any] = {}

        def visit(node, matched_pattern, qconfig):
            def visit_arg(arg):
                is_weight = False
                if isinstance(node, Node) and node.op == 'call_function' and \
                        node.target in WEIGHT_INDEX_DICT:
                    for i, node_arg in enumerate(node.args):
                        if arg is node_arg and i in \
                                WEIGHT_INDEX_DICT[node.target]:  # type: ignore
                            is_weight = True
                if qconfig is not None and \
                        (activation_is_statically_quantized(qconfig) or
                         is_weight):
                    act_post_process_ctr = qconfig.weight if is_weight else \
                        qconfig.activation
                    quants[arg.name] = (
                        DefaultQuantizeHandler(self, arg), qconfig, is_weight)
                    # overwrite the constructor from qconfig
                    act_post_process_ctr = \
                        get_default_output_activation_post_process_map().get(
                            matched_pattern,
                            act_post_process_ctr)
                    # overwrite the previous activation post process
                    # constructor if necessary
                    quants[arg.name] = (
                        DefaultQuantizeHandler(self, arg), act_post_process_ctr)
            return visit_arg

        for node in graph.nodes:
            if node.name in matches:
                root_node, matched_nodes, matched_pattern, quantize_handler, \
                    qconfig = matches[node.name]
                # don't attach observer/fake_quant for CopyNode
                if isinstance(quantize_handler, CopyNode):
                    qconfig = None
                if root_node is node and \
                        input_output_observed(quantize_handler):
                    # matched_nodes[-1] is the first op in the sequence and
                    # matched_nodes[0] is the last op in the sequence
                    # inputs
                    # matched_pattern is set to None for inputs because
                    # we only want to select the QuantizeHandler object based
                    # on the pattern for the output; inputs will always use
                    # DefaultQuantizeHandler
                    map_arg(matched_nodes[-1].args, visit(matched_nodes[-1],
                                                          None, qconfig))
                    map_arg(matched_nodes[-1].kwargs, visit(matched_nodes[-1],
                                                            None, qconfig))

                    # output
                    # we don't insert observer for output of standalone module
                    if not isinstance(
                            quantize_handler, StandaloneModuleQuantizeHandler):
                        # passing in matched_pattern here so that we can
                        # customize the activation_post_process constructor
                        # for the output based on the pattern, e.g.
                        # for sigmoid op we'll use
                        # default_affine_fixed_qparam_fake_quant
                        map_arg(matched_nodes[0],
                                visit(None, matched_pattern, qconfig))
        return quants