fx quant: clean up functions in _prepare (#48773)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48773

Makes util functions in `_prepare` have no side effects,
all dependencies are now in arguments.

Note: arg names are added in the order they appeared in the function
code. It's not the most readable, but it is the lowest risk. This can
be cleaned up in future PRs if needed.
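
For illustration, a minimal sketch of the pattern this change applies; the names below are generic stand-ins, not the actual helpers from this file:

```
# Before: the helper is nested inside _prepare and closes over local
# state, so its side effects are hidden in the enclosing scope
def _prepare(model):
    env = {}
    observed_node_names_set = set()

    def insert_observer(node, observer):
        observed_node_names_set.add(node.name)  # mutates enclosing scope
        env[node.name] = observer

# After: a module-level function; every dependency is an explicit argument
def insert_observer(node, observer, env, observed_node_names_set):
    observed_node_names_set.add(node.name)  # mutates only caller-provided state
    env[node.name] = observer
```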

```
python test/test_quantization.py TestQuantizeFx
```

Test Plan: Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25295839

fbshipit-source-id: 60c687f6b64924473f969541c8703118e4f7d16e
@@ -57,7 +57,7 @@ from collections import OrderedDict
import warnings
import re
from typing import Optional, Dict, Any, List, Union, Tuple
from typing import Optional, Dict, Any, List, Union, Tuple, Set
# Define helper types
@@ -246,6 +246,162 @@ def get_qconfig(modules, qconfig_dict, module_name, global_qconfig):
qconfig_dict, module_name, module_name_regex_qconfig)
return module_name_qconfig
def insert_observer(
node, observer, model_device, model,
activation_post_process_map, env, observed_graph, load_arg,
observed_node_names_set):
"""Insert observer for node by modifying the observed_graph and
attach observer module to the model
Args:
node: Node
observer: observer/fake_quantize module instance
"""
# respect device affinity when adding observers
if model_device:
observer.to(model_device)
# add observer module as attribute
prefix = node.name + '_activation_post_process_'
get_new_observer_name = get_new_attr_name_with_prefix(prefix)
observer_name = get_new_observer_name(model)
setattr(model, observer_name, observer)
# put observer instance in the activation_post_process map
assert activation_post_process_map is not None
activation_post_process_map[node.name] = observer
# insert observer call
env[node.name] = observed_graph.create_node(
'call_module', observer_name, (load_arg(node),), {})
observed_node_names_set.add(node.name)
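
A small illustration of the attribute-naming scheme above; the numeric suffix is an assumption about how get_new_attr_name_with_prefix deduplicates, not something shown in this diff:

```
import torch
from torch.quantization import MinMaxObserver

model = torch.nn.Module()
observer = MinMaxObserver()
# For a node named 'conv1', the helper builds a prefix and then a unique
# attribute name (suffix assumed); the observer becomes a submodule:
observer_name = 'conv1_activation_post_process_0'
setattr(model, observer_name, observer)
assert hasattr(model, observer_name)
```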
def insert_observer_for_special_module(
quantize_handler, modules, prepare_custom_config_dict, qconfig,
node):
""" Insert observer for custom module and standalone module
Returns: standalone_module_input_idxs: the indexs for inputs that
needs to be observed by parent module
"""
assert modules is not None
if isinstance(quantize_handler, CustomModuleQuantizeHandler):
custom_module = modules[node.target]
custom_module_class_mapping = prepare_custom_config_dict.get(
"float_to_observed_custom_module_class", {})
observed_custom_module_class = \
get_swapped_custom_module_class(
custom_module, custom_module_class_mapping, qconfig)
observed_custom_module = \
observed_custom_module_class.from_float(custom_module)
parent_name, name = _parent_name(node.target)
setattr(modules[parent_name], name, observed_custom_module)
elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
# observe standalone module
standalone_module = modules[node.target]
prepare = \
torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore
observed_standalone_module = \
prepare(standalone_module, {"": qconfig})
observed_standalone_module.qconfig = qconfig
observed_standalone_module = mark_observed_standalone_module(
observed_standalone_module)
parent_name, name = _parent_name(node.target)
setattr(modules[parent_name], name,
observed_standalone_module)
modules[node.target] = observed_standalone_module
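
For reference, a sketch of the prepare_custom_config_dict shape consulted above; the two module classes are hypothetical placeholders, and only the dict key and the from_float hook come from the code:

```
import torch

class MyFloatCustomModule(torch.nn.Module):  # hypothetical float module
    pass

class MyObservedCustomModule(torch.nn.Module):  # hypothetical observed version
    @classmethod
    def from_float(cls, float_module):
        # called by the pass above to build the observed module
        return cls()

prepare_custom_config_dict = {
    "float_to_observed_custom_module_class": {
        MyFloatCustomModule: MyObservedCustomModule,
    },
}
```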
def insert_observer_for_output_of_the_node(
node,
quantize_handler,
qconfig,
modules,
model,
pattern,
model_device,
activation_post_process_map,
env,
observed_graph,
load_arg,
observed_node_names_set,
matched_nodes):
""" Insert observer/fake_quantize module for output of the observed
module if needed
"""
# don't need to insert observer for output if activation does not
# need to be statically quantized
assert modules is not None
if activation_is_statically_quantized(qconfig):
if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) \
and model.training:
# we only insert fake quantize module in qat
assert pattern is not None
activation_post_process_ctr = \
get_default_output_activation_post_process_map().get(
pattern, None)
assert activation_post_process_ctr is not None, \
"activation_post_process constructor not provided " + \
"for pattern:" + str(pattern)
insert_observer(
node, activation_post_process_ctr(), model_device,
model, activation_post_process_map, env, observed_graph,
load_arg, observed_node_names_set)
elif (isinstance(quantize_handler,
FixedQParamsOpQuantizeHandler) and
not model.training) or \
isinstance(quantize_handler, CopyNode):
# insert observer for output of observed module, or
# mark the output as observed
assert node.op in [
'call_module',
'call_function',
'call_method'], \
'CopyNode of type ' + node.op + ' is not handled'
def is_observed(input_arg):
if isinstance(input_arg, Node):
return input_arg.name in observed_node_names_set
elif isinstance(input_arg, list):
return all(map(is_observed, input_arg))
# propagate observed property from input
if is_observed(node.args[0]):
observed_node_names_set.add(node.name)
elif ((isinstance(quantize_handler, Add) or
isinstance(quantize_handler, Mul)) and
quantize_handler.num_node_args == 1):
assert matched_nodes is not None
input_node = matched_nodes[-1] # first node in the sequence
def input_is_observed(arg):
return (isinstance(arg, Node) and
arg.name in observed_node_names_set)
# This checks whether one of the arguments of add/mul
# is an observed node.
# If both of the inputs are numbers,
# we will not consider the output to be observed
if (input_is_observed(input_node.args[0]) or
input_is_observed(input_node.args[1])):
observed_node_names_set.add(node.name)
elif isinstance(quantize_handler,
StandaloneModuleQuantizeHandler):
# output is observed in the standalone module
return
elif (quantize_handler.all_node_args and
input_output_observed(quantize_handler)):
# observer for outputs
new_observer = qconfig.activation()
insert_observer(
node, new_observer, model_device, model,
activation_post_process_map, env, observed_graph,
load_arg, observed_node_names_set)
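
To make the num_node_args == 1 branch concrete, a sketch of the two add/mul situations it distinguishes:

```
import torch

def f(x, y):
    # one operand is a Python number, so the add has a single Node arg;
    # its output counts as observed only if x itself is observed
    out1 = x + 3
    # both operands are graph Nodes; handled by the other branches above
    out2 = x + y
    return out1, out2
```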
def insert_observer_for_input_arg_of_observed_node(
node, observed_node_names_set, quants,
model_device, model, activation_post_process_map, env, observed_graph,
load_arg):
if node.name not in observed_node_names_set and node.name in quants:
_, activation_post_process_ctr = quants[node.name]
if activation_post_process_ctr is not None:
insert_observer(
node, activation_post_process_ctr(),
model_device, model, activation_post_process_map,
env, observed_graph, load_arg, observed_node_names_set)
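
From the unpacking above, each quants entry is a two-tuple whose second element is an observer constructor or None; a sketch of that shape (the first element is ignored here and shown as None, and the constructor is just an example, not necessarily what the pass that builds quants produces):

```
import torch
from torch.quantization import MinMaxObserver

# second tuple element: a constructor for the observer to insert, or None
quants = {
    'linear_input': (None, MinMaxObserver.with_args(dtype=torch.quint8)),
    'no_observer_needed': (None, None),
}
```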
# A dictionary for querying the weight index for a given op
WEIGHT_INDEX_DICT = {
torch.nn.functional.conv2d : [1],
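
As an aside on how such a table is typically consulted (a sketch, not necessarily the exact helper this file defines): in torch.nn.functional.conv2d(input, weight, bias, ...), the weight is positional argument 1, hence the [1].

```
def arg_index_is_weight(node, index):
    # sketch: does positional arg `index` of this call_function node
    # hold a weight, according to WEIGHT_INDEX_DICT?
    return (node.op == 'call_function' and
            node.target in WEIGHT_INDEX_DICT and
            index in WEIGHT_INDEX_DICT[node.target])
```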
@@ -388,7 +544,7 @@ class Quantizer:
self.activation_post_process_map = dict()
env: Dict[Any, Any] = {}
observed_graph = Graph()
observed_node_names_set = set()
observed_node_names_set: Set[str] = set()
def load_arg(a):
return map_arg(a, lambda node: env[node.name])
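
load_arg relies on map_arg to walk a possibly nested argument structure and replace each Node with its counterpart in the new graph; a minimal standalone sketch of the same env/node_copy pattern:

```
import torch
from torch.fx import symbolic_trace, Graph

gm = symbolic_trace(torch.nn.ReLU())
new_graph = Graph()
env = {}
for node in gm.graph.nodes:
    # node_copy remaps each argument Node through env, exactly the role
    # load_arg plays in the pass above
    env[node.name] = new_graph.node_copy(node, lambda n: env[n.name])
```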
@@ -404,140 +560,6 @@ class Quantizer:
'activation_post_process_')
model_device = assert_and_get_unique_device(model)
def insert_observer(node, observer):
"""Insert observer for node by modifying the observed_graph and
attach observer module to the model
Args:
node: Node
observer: observer/fake_quantize module instance
"""
# respect device affinity when adding observers
if model_device:
observer.to(model_device)
# add observer module as attribute
prefix = node.name + '_activation_post_process_'
get_new_observer_name = get_new_attr_name_with_prefix(prefix)
observer_name = get_new_observer_name(model)
setattr(model, observer_name, observer)
# put observer instance activation_post_process map
assert self.activation_post_process_map is not None
self.activation_post_process_map[node.name] = observer
# insert observer call
env[node.name] = observed_graph.create_node(
'call_module', observer_name, (load_arg(node),), {})
observed_node_names_set.add(node.name)
def insert_observer_for_special_module(quantize_handler):
""" Insert observer for custom module and standalone module
Returns: standalone_module_input_idxs: the indexs for inputs that
needs to be observed by parent module
"""
assert self.modules is not None
if isinstance(quantize_handler, CustomModuleQuantizeHandler):
custom_module = self.modules[node.target]
custom_module_class_mapping = prepare_custom_config_dict.get(
"float_to_observed_custom_module_class", {})
observed_custom_module_class = \
get_swapped_custom_module_class(
custom_module, custom_module_class_mapping, qconfig)
observed_custom_module = \
observed_custom_module_class.from_float(custom_module)
parent_name, name = _parent_name(node.target)
setattr(self.modules[parent_name], name, observed_custom_module)
elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
# observe standalone module
standalone_module = self.modules[node.target]
prepare = \
torch.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore
observed_standalone_module = \
prepare(standalone_module, {"": qconfig})
observed_standalone_module.qconfig = qconfig
observed_standalone_module = mark_observed_standalone_module(
observed_standalone_module)
parent_name, name = _parent_name(node.target)
setattr(self.modules[parent_name], name,
observed_standalone_module)
self.modules[node.target] = observed_standalone_module
def insert_observer_for_output_of_the_node(
node,
quantize_handler,
qconfig):
""" Insert observer/fake_quantize module for output of the observed
module if needed
"""
# don't need to insert observer for output if activation does not
# need to be statically quantized
assert self.modules is not None
if activation_is_statically_quantized(qconfig):
if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) \
and model.training:
# we only insert fake quantize module in qat
assert pattern is not None
activation_post_process_ctr = \
get_default_output_activation_post_process_map().get(
pattern, None)
assert activation_post_process_ctr is not None, \
"activation_post_process constructor not provided " + \
"for pattern:" + str(pattern)
insert_observer(node, activation_post_process_ctr())
elif (isinstance(quantize_handler,
FixedQParamsOpQuantizeHandler) and
not model.training) or \
isinstance(quantize_handler, CopyNode):
# inserting observers for output of observed module, or
# mark the output as observed
assert node.op in [
'call_module',
'call_function',
'call_method'], \
'CopyNode of type ' + node.op + ' is not handled'
def is_observed(input_arg):
if isinstance(input_arg, Node):
return input_arg.name in observed_node_names_set
elif isinstance(input_arg, list):
return all(map(is_observed, input_arg))
# propagate observed property from input
if is_observed(node.args[0]):
observed_node_names_set.add(node.name)
elif ((isinstance(quantize_handler, Add) or
isinstance(quantize_handler, Mul)) and
quantize_handler.num_node_args == 1):
assert matched_nodes is not None
input_node = matched_nodes[-1] # first node in the sequence
def input_is_observed(arg):
return (isinstance(arg, Node) and
arg.name in observed_node_names_set)
# This is checking if one of the argument of add/mul
# is an observed node
# If both of the inputs are number,
# we will not consider the output to be observed
if (input_is_observed(input_node.args[0]) or
input_is_observed(input_node.args[1])):
observed_node_names_set.add(node.name)
elif isinstance(quantize_handler,
StandaloneModuleQuantizeHandler):
# output is observed in the standalone module
return
elif (quantize_handler.all_node_args and
input_output_observed(quantize_handler)):
# observer for outputs
new_observer = qconfig.activation()
insert_observer(node, new_observer)
def insert_observer_for_input_arg_of_observed_node(arg):
"""
Input:
arg: input arg node for another observed node, e.g.
input activation for functional linear node
"""
if node.name not in observed_node_names_set and node.name in quants:
_, activation_post_process_ctr = quants[node.name]
if activation_post_process_ctr is not None:
insert_observer(node, activation_post_process_ctr())
result_node : Optional[Node] = None
for node in model.graph.nodes:
if node.op == 'output':
@@ -556,12 +578,20 @@ class Quantizer:
# index for input of custom module that needs to be observed in
# parent
if qconfig is not None:
insert_observer_for_special_module(obj)
insert_observer_for_special_module(
obj, self.modules, prepare_custom_config_dict, qconfig,
node)
insert_observer_for_output_of_the_node(
node, obj, qconfig)
node, obj, qconfig, self.modules, model, pattern,
model_device, self.activation_post_process_map, env,
observed_graph, load_arg, observed_node_names_set,
matched_nodes)
else:
env[node.name] = observed_graph.node_copy(node, load_arg)
insert_observer_for_input_arg_of_observed_node(node)
insert_observer_for_input_arg_of_observed_node(
node, observed_node_names_set, quants,
model_device, model, self.activation_post_process_map, env,
observed_graph, load_arg)
model = GraphModule(model, observed_graph)
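
For context, a minimal end-to-end sketch driving this pass through the public FX entry point (API shape as of this change; treat the details as approximate):

```
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

float_model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 3, 1),
    torch.nn.ReLU(),
).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
# prepare_fx symbolically traces the model and runs the observer-insertion
# pass refactored in this commit
observed_model = prepare_fx(float_model, qconfig_dict)
```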