fx quant: clean up functions in _prepare (#48773)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48773

Makes the util functions in `_prepare` free of side effects; all dependencies are now passed in as arguments. Note: argument names were added in the order they appeared in the function code. This is not the most readable ordering, but it is the lowest-risk change, and it can be cleaned up in future PRs if needed.

```
python test/test_quantization.py TestQuantizeFx
```

Test Plan: Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25295839

fbshipit-source-id: 60c687f6b64924473f969541c8703118e4f7d16e
This commit is contained in:
parent 536352e86f
commit c98c617b44
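The diff below follows a standard closure-to-pure-function refactor: helpers that previously captured `self` and enclosing locals become module-level functions whose dependencies all arrive as parameters. A minimal sketch of the pattern, with hypothetical names (`helper`, `Before`, `After`) that are not from this PR:

```python
# Before: the helper is a closure that mutates enclosing state, so it can
# only live (and be tested) inside the method that defines it.
class Before:
    def run(self):
        seen = set()

        def helper(name):
            seen.add(name)  # hidden dependency on the enclosing scope

        helper("a")
        return seen

# After: the same logic as a module-level function; every dependency is an
# explicit argument, so it can be reused and tested without an instance.
def helper(name, seen):
    seen.add(name)

class After:
    def run(self):
        seen = set()
        helper("a", seen)  # all state is passed in at the call site
        return seen
```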
```diff
@@ -57,7 +57,7 @@ from collections import OrderedDict
 import warnings
 import re
 
-from typing import Optional, Dict, Any, List, Union, Tuple
+from typing import Optional, Dict, Any, List, Union, Tuple, Set
 
 # Define helper types
 
```
```diff
@@ -246,6 +246,162 @@ def get_qconfig(modules, qconfig_dict, module_name, global_qconfig):
         qconfig_dict, module_name, module_name_regex_qconfig)
     return module_name_qconfig
 
+def insert_observer(
+        node, observer, model_device, model,
+        activation_post_process_map, env, observed_graph, load_arg,
+        observed_node_names_set):
+    """Insert observer for node by modifying the observed_graph and
+       attach observer module to the model
+       Args:
+         node: Node
+         observer: observer/fake_quantize module instance
+    """
+    # respect device affinity when adding observers
+    if model_device:
+        observer.to(model_device)
+    # add observer module as attribute
+    prefix = node.name + '_activation_post_process_'
+    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
+    observer_name = get_new_observer_name(model)
+    setattr(model, observer_name, observer)
+    # put observer instance activation_post_process map
+    assert activation_post_process_map is not None
+    activation_post_process_map[node.name] = observer
+    # insert observer call
+    env[node.name] = observed_graph.create_node(
+        'call_module', observer_name, (load_arg(node),), {})
+    observed_node_names_set.add(node.name)
+
+def insert_observer_for_special_module(
+        quantize_handler, modules, prepare_custom_config_dict, qconfig,
+        node):
+    """ Insert observer for custom module and standalone module
+    Returns: standalone_module_input_idxs: the indexs for inputs that
+    needs to be observed by parent module
+    """
+    assert modules is not None
+    if isinstance(quantize_handler, CustomModuleQuantizeHandler):
+        custom_module = modules[node.target]
+        custom_module_class_mapping = prepare_custom_config_dict.get(
+            "float_to_observed_custom_module_class", {})
+        observed_custom_module_class = \
+            get_swapped_custom_module_class(
+                custom_module, custom_module_class_mapping, qconfig)
+        observed_custom_module = \
+            observed_custom_module_class.from_float(custom_module)
+        parent_name, name = _parent_name(node.target)
+        setattr(modules[parent_name], name, observed_custom_module)
+    elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
+        # observe standalone module
+        standalone_module = modules[node.target]
+        prepare = \
+            torch.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore
+        observed_standalone_module = \
+            prepare(standalone_module, {"": qconfig})
+        observed_standalone_module.qconfig = qconfig
+        observed_standalone_module = mark_observed_standalone_module(
+            observed_standalone_module)
+        parent_name, name = _parent_name(node.target)
+        setattr(modules[parent_name], name,
+                observed_standalone_module)
+        modules[node.target] = observed_standalone_module
+
+def insert_observer_for_output_of_the_node(
+        node,
+        quantize_handler,
+        qconfig,
+        modules,
+        model,
+        pattern,
+        model_device,
+        activation_post_process_map,
+        env,
+        observed_graph,
+        load_arg,
+        observed_node_names_set,
+        matched_nodes):
+    """ Insert observer/fake_quantize module for output of the observed
+    module if needed
+    """
+    # don't need to insert observer for output if activation does not
+    # need to be statically quantized
+    assert modules is not None
+    if activation_is_statically_quantized(qconfig):
+        if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) \
+                and model.training:
+            # we only insert fake quantize module in qat
+            assert pattern is not None
+            activation_post_process_ctr = \
+                get_default_output_activation_post_process_map().get(
+                    pattern, None)
+            assert activation_post_process_ctr is not None, \
+                "activation_post_process constructor not provided " + \
+                "for pattern:" + str(pattern)
+            insert_observer(
+                node, activation_post_process_ctr(), model_device,
+                model, activation_post_process_map, env, observed_graph,
+                load_arg, observed_node_names_set)
+        elif (isinstance(quantize_handler,
+                         FixedQParamsOpQuantizeHandler) and
+              not model.training) or \
+                isinstance(quantize_handler, CopyNode):
+            # inserting observers for output of observed module, or
+            # mark the output as observed
+            assert node.op in [
+                'call_module',
+                'call_function',
+                'call_method'], \
+                'CopyNode of type ' + node.op + ' is not handled'
+
+            def is_observed(input_arg):
+                if isinstance(input_arg, Node):
+                    return input_arg.name in observed_node_names_set
+                elif isinstance(input_arg, list):
+                    return all(map(is_observed, input_arg))
+            # propagate observed property from input
+            if is_observed(node.args[0]):
+                observed_node_names_set.add(node.name)
+        elif ((isinstance(quantize_handler, Add) or
+               isinstance(quantize_handler, Mul)) and
+              quantize_handler.num_node_args == 1):
+            assert matched_nodes is not None
+            input_node = matched_nodes[-1]  # first node in the sequence
+
+            def input_is_observed(arg):
+                return (isinstance(arg, Node) and
+                        arg.name in observed_node_names_set)
+            # This is checking if one of the argument of add/mul
+            # is an observed node
+            # If both of the inputs are number,
+            # we will not consider the output to be observed
+            if (input_is_observed(input_node.args[0]) or
+                    input_is_observed(input_node.args[1])):
+                observed_node_names_set.add(node.name)
+        elif isinstance(quantize_handler,
+                        StandaloneModuleQuantizeHandler):
+            # output is observed in the standalone module
+            return
+        elif (quantize_handler.all_node_args and
+              input_output_observed(quantize_handler)):
+            # observer for outputs
+            new_observer = qconfig.activation()
+            insert_observer(
+                node, new_observer, model_device, model,
+                activation_post_process_map, env, observed_graph,
+                load_arg, observed_node_names_set)
+
+def insert_observer_for_input_arg_of_observed_node(
+        node, observed_node_names_set, quants,
+        model_device, model, activation_post_process_map, env, observed_graph,
+        load_arg):
+    if node.name not in observed_node_names_set and node.name in quants:
+        _, activation_post_process_ctr = quants[node.name]
+        if activation_post_process_ctr is not None:
+            insert_observer(
+                node, activation_post_process_ctr(),
+                model_device, model, activation_post_process_map,
+                env, observed_graph, load_arg, observed_node_names_set)
+
 # A dictionary for querying the weight index for a given op
 WEIGHT_INDEX_DICT = {
     torch.nn.functional.conv2d : [1],
```
```diff
@@ -388,7 +544,7 @@ class Quantizer:
         self.activation_post_process_map = dict()
         env: Dict[Any, Any] = {}
         observed_graph = Graph()
-        observed_node_names_set = set()
+        observed_node_names_set: Set[str] = set()
 
         def load_arg(a):
             return map_arg(a, lambda node: env[node.name])
```
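The `Set` import added in the first hunk exists to support this annotation: a bare `set()` gives mypy no element type to infer, so it typically reports "Need type annotation" for the variable. A minimal sketch, with an illustrative variable name not taken from the PR:

```python
from typing import Set

# A bare `names = set()` leaves the element type unknown to mypy; the
# explicit Set[str] annotation pins it up front for an empty set.
names: Set[str] = set()
names.add("conv1_activation_post_process_0")
```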
```diff
@@ -404,140 +560,6 @@ class Quantizer:
             'activation_post_process_')
         model_device = assert_and_get_unique_device(model)
 
-        def insert_observer(node, observer):
-            """Insert observer for node by modifying the observed_graph and
-               attach observer module to the model
-               Args:
-                 node: Node
-                 observer: observer/fake_quantize module instance
-            """
-            # respect device affinity when adding observers
-            if model_device:
-                observer.to(model_device)
-            # add observer module as attribute
-            prefix = node.name + '_activation_post_process_'
-            get_new_observer_name = get_new_attr_name_with_prefix(prefix)
-            observer_name = get_new_observer_name(model)
-            setattr(model, observer_name, observer)
-            # put observer instance activation_post_process map
-            assert self.activation_post_process_map is not None
-            self.activation_post_process_map[node.name] = observer
-            # insert observer call
-            env[node.name] = observed_graph.create_node(
-                'call_module', observer_name, (load_arg(node),), {})
-            observed_node_names_set.add(node.name)
-
-        def insert_observer_for_special_module(quantize_handler):
-            """ Insert observer for custom module and standalone module
-            Returns: standalone_module_input_idxs: the indexs for inputs that
-            needs to be observed by parent module
-            """
-            assert self.modules is not None
-            if isinstance(quantize_handler, CustomModuleQuantizeHandler):
-                custom_module = self.modules[node.target]
-                custom_module_class_mapping = prepare_custom_config_dict.get(
-                    "float_to_observed_custom_module_class", {})
-                observed_custom_module_class = \
-                    get_swapped_custom_module_class(
-                        custom_module, custom_module_class_mapping, qconfig)
-                observed_custom_module = \
-                    observed_custom_module_class.from_float(custom_module)
-                parent_name, name = _parent_name(node.target)
-                setattr(self.modules[parent_name], name, observed_custom_module)
-            elif isinstance(quantize_handler, StandaloneModuleQuantizeHandler):
-                # observe standalone module
-                standalone_module = self.modules[node.target]
-                prepare = \
-                    torch.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore
-                observed_standalone_module = \
-                    prepare(standalone_module, {"": qconfig})
-                observed_standalone_module.qconfig = qconfig
-                observed_standalone_module = mark_observed_standalone_module(
-                    observed_standalone_module)
-                parent_name, name = _parent_name(node.target)
-                setattr(self.modules[parent_name], name,
-                        observed_standalone_module)
-                self.modules[node.target] = observed_standalone_module
-
-        def insert_observer_for_output_of_the_node(
-                node,
-                quantize_handler,
-                qconfig):
-            """ Insert observer/fake_quantize module for output of the observed
-            module if needed
-            """
-            # don't need to insert observer for output if activation does not
-            # need to be statically quantized
-            assert self.modules is not None
-            if activation_is_statically_quantized(qconfig):
-                if isinstance(quantize_handler, FixedQParamsOpQuantizeHandler) \
-                        and model.training:
-                    # we only insert fake quantize module in qat
-                    assert pattern is not None
-                    activation_post_process_ctr = \
-                        get_default_output_activation_post_process_map().get(
-                            pattern, None)
-                    assert activation_post_process_ctr is not None, \
-                        "activation_post_process constructor not provided " + \
-                        "for pattern:" + str(pattern)
-                    insert_observer(node, activation_post_process_ctr())
-                elif (isinstance(quantize_handler,
-                                 FixedQParamsOpQuantizeHandler) and
-                      not model.training) or \
-                        isinstance(quantize_handler, CopyNode):
-                    # inserting observers for output of observed module, or
-                    # mark the output as observed
-                    assert node.op in [
-                        'call_module',
-                        'call_function',
-                        'call_method'], \
-                        'CopyNode of type ' + node.op + ' is not handled'
-
-                    def is_observed(input_arg):
-                        if isinstance(input_arg, Node):
-                            return input_arg.name in observed_node_names_set
-                        elif isinstance(input_arg, list):
-                            return all(map(is_observed, input_arg))
-                    # propagate observed property from input
-                    if is_observed(node.args[0]):
-                        observed_node_names_set.add(node.name)
-                elif ((isinstance(quantize_handler, Add) or
-                       isinstance(quantize_handler, Mul)) and
-                      quantize_handler.num_node_args == 1):
-                    assert matched_nodes is not None
-                    input_node = matched_nodes[-1]  # first node in the sequence
-
-                    def input_is_observed(arg):
-                        return (isinstance(arg, Node) and
-                                arg.name in observed_node_names_set)
-                    # This is checking if one of the argument of add/mul
-                    # is an observed node
-                    # If both of the inputs are number,
-                    # we will not consider the output to be observed
-                    if (input_is_observed(input_node.args[0]) or
-                            input_is_observed(input_node.args[1])):
-                        observed_node_names_set.add(node.name)
-                elif isinstance(quantize_handler,
-                                StandaloneModuleQuantizeHandler):
-                    # output is observed in the standalone module
-                    return
-                elif (quantize_handler.all_node_args and
-                      input_output_observed(quantize_handler)):
-                    # observer for outputs
-                    new_observer = qconfig.activation()
-                    insert_observer(node, new_observer)
-
-        def insert_observer_for_input_arg_of_observed_node(arg):
-            """
-            Input:
-              arg: input arg node for another observed node, e.g.
-              input activaiton for functional linear node
-            """
-            if node.name not in observed_node_names_set and node.name in quants:
-                _, activation_post_process_ctr = quants[node.name]
-                if activation_post_process_ctr is not None:
-                    insert_observer(node, activation_post_process_ctr())
-
         result_node : Optional[Node] = None
         for node in model.graph.nodes:
             if node.op == 'output':
```
```diff
@@ -556,12 +578,20 @@ class Quantizer:
                 # index for input of custom module that needs to be observed in
                 # parent
                 if qconfig is not None:
-                    insert_observer_for_special_module(obj)
+                    insert_observer_for_special_module(
+                        obj, self.modules, prepare_custom_config_dict, qconfig,
+                        node)
                     insert_observer_for_output_of_the_node(
-                        node, obj, qconfig)
+                        node, obj, qconfig, self.modules, model, pattern,
+                        model_device, self.activation_post_process_map, env,
+                        observed_graph, load_arg, observed_node_names_set,
+                        matched_nodes)
             else:
                 env[node.name] = observed_graph.node_copy(node, load_arg)
-                insert_observer_for_input_arg_of_observed_node(node)
+                insert_observer_for_input_arg_of_observed_node(
+                    node, observed_node_names_set, quants,
+                    model_device, model, self.activation_post_process_map, env,
+                    observed_graph, load_arg)
 
 
         model = GraphModule(model, observed_graph)
```
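One payoff of this refactor: with every dependency made an explicit argument, a helper shaped like `insert_observer` can be exercised with lightweight fakes instead of a fully constructed `Quantizer`. The sketch below is illustrative only; `FakeNode`, `FakeObserver`, and the simplified `record_observer` body are hypothetical stand-ins, not the PR's actual code:

```python
class FakeNode:
    def __init__(self, name):
        self.name = name

class FakeObserver:
    def __init__(self):
        self.device = None

    def to(self, device):
        self.device = device

def record_observer(node, observer, model_device,
                    activation_post_process_map, observed_node_names_set):
    # Simplified stand-in for insert_observer: same explicit-dependency
    # style as the refactored helper, minus the FX graph mutation.
    if model_device:
        observer.to(model_device)
    activation_post_process_map[node.name] = observer
    observed_node_names_set.add(node.name)

# All state is constructed locally and passed in -- nothing hides in `self`.
node, obs = FakeNode("linear1"), FakeObserver()
post_process_map, observed = {}, set()
record_observer(node, obs, "cpu", post_process_map, observed)
assert post_process_map["linear1"] is obs and "linear1" in observed
```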