import copy
import re

import torch
import torch.nn as nn
from torch.ao.quantization import QuantType
from torch.ao.quantization.utils import is_per_tensor, is_per_channel
from torch.ao.quantization.quantize import is_activation_post_process

from torch.fx import GraphModule, map_arg

from torch.fx.graph import (
    Graph,
    Node,
)
from .custom_config import PrepareCustomConfig

from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
from collections import namedtuple
import operator
import warnings

# TODO: revisit this list. Many helper methods shouldn't be public
__all__ = [
    "all_node_args_except_first",
    "all_node_args_have_no_tensors",
    "assert_and_get_unique_device",
    "BIAS_INDEX_DICT",
    "collect_producer_nodes",
    "create_getattr_from_value",
    "create_node_from_old_node_preserve_meta",
    "create_qparam_nodes",
    "EMPTY_ARG_DICT",
    "get_custom_module_class_keys",
    "get_linear_prepack_op_for_dtype",
    "get_new_attr_name_with_prefix",
    "get_non_observable_arg_indexes_and_types",
    "get_per_tensor_qparams",
    "get_qconv_op",
    "get_qconv_prepack_op",
    "get_quantize_node_info",
    "graph_module_from_producer_nodes",
    "graph_pretty_str",
    "is_get_tensor_info_node",
    "maybe_get_next_module",
    "NodeInfo",
    "node_return_type_is_int",
    "NON_OBSERVABLE_ARG_DICT",
    "NON_QUANTIZABLE_WEIGHT_OPS",
    "quantize_node",
    "return_arg_list",
    "WEIGHT_INDEX_DICT",
    "get_skipped_module_name_and_classes",
]

# A dictionary for querying the weight index for a given op
WEIGHT_INDEX_DICT = {
    torch.nn.functional.conv1d : [1],
    torch.nn.functional.conv2d : [1],
    torch.nn.functional.conv3d : [1],
    torch.nn.functional.linear : [1],
    torch.nn.functional.layer_norm : [2],
    torch.nn.functional.group_norm : [2],
    torch.nn.functional.instance_norm : [3],
}

NON_QUANTIZABLE_WEIGHT_OPS = {torch.nn.functional.layer_norm, torch.nn.functional.group_norm, torch.nn.functional.instance_norm}

BIAS_INDEX_DICT = {
    torch.nn.functional.conv1d : [2],
    torch.nn.functional.conv2d : [2],
    torch.nn.functional.conv3d : [2],
    torch.nn.functional.linear : [2],
    torch.nn.functional.layer_norm : [3],
    torch.nn.functional.group_norm : [3],
    torch.nn.functional.instance_norm : [4],
}

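# For example, for F.linear(input, weight, bias) the two tables above give
# WEIGHT_INDEX_DICT[torch.nn.functional.linear] == [1] (the `weight` arg) and
# BIAS_INDEX_DICT[torch.nn.functional.linear] == [2] (the `bias` arg).
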
def graph_pretty_str(g, shorten=True) -> str:
    """Returns a printable representation of the ops in the graph of g.
    If shorten is True, tries to abbreviate fields.
    """
    built_in_func_re = re.compile('<built-in function (.*)>')
    built_in_meth_re = re.compile('<built-in method (.*) of type.*>')
    op_dict = {
        'placeholder': 'plchdr',
        'get_attr': 'gt_prm',
        'call_function': 'cl_fun',
        'call_module': 'cl_mod',
        'call_method': 'cl_meth',
    }

    max_lens = {}
    col_names = ("name", "op", "target", "args", "kwargs")
    for s in col_names:
        max_lens[s] = len(s)

    results = []
    for n in g.nodes:
        # activation_post_process_0 -> obs_0
        name = str(n.name)
        if shorten:
            name = name.replace("activation_post_process", "obs")

        op = str(n.op)
        # placeholder -> plchdr, and so on
        if shorten and op in op_dict:
            op = op_dict[op]

        target = str(n.target)
        # <built-in function foo> -> <bi_fun foo>, and so on
        if shorten:
            built_in_func = built_in_func_re.search(target)
            if built_in_func:
                target = f"<bi_fun {built_in_func.group(1)}>"
            built_in_meth = built_in_meth_re.search(target)
            if built_in_meth:
                target = f"<bi_meth {built_in_meth.group(1)}>"
            target = target.replace("activation_post_process", "obs")

        args = str(n.args)
        if shorten:
            args = args.replace("activation_post_process", "obs")

        kwargs = str(n.kwargs)

        # calculate maximum length of each column, so we can tabulate properly
        for k, v in zip(col_names, (name, op, target, args, kwargs)):
            max_lens[k] = max(max_lens[k], len(v))
        results.append([name, op, target, args, kwargs])

    res_str = ""
    format_str = "{:<{name}} {:<{op}} {:<{target}} {:<{args}} {:<{kwargs}}\n"
    res_str += format_str.format(*col_names, **max_lens)
    for result in results:
        res_str += format_str.format(*result, **max_lens)

    # print an extra note on abbreviations which change attribute names,
    # since users will have to un-abbreviate for further debugging
    if shorten:
        res_str += "*obs_{n} = activation_post_process_{n}\n"
    return res_str

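# Example usage (an illustrative sketch; `MyModule` is a stand-in for any
# symbolically traceable module):
#   gm = torch.fx.symbolic_trace(MyModule())
#   print(graph_pretty_str(gm.graph))
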
def get_per_tensor_qparams(activation_post_process):
    assert is_per_tensor(activation_post_process.qscheme), 'Only per tensor quantization is supported'
    scale, zero_point = activation_post_process.calculate_qparams()
    scale = float(scale)
    zero_point = int(zero_point)
    dtype = activation_post_process.dtype
    return scale, zero_point, dtype

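# Example (sketch, assuming `obs` is a calibrated per-tensor observer, e.g. a
# torch.ao.quantization.MinMaxObserver that has already seen data):
#   scale, zero_point, dtype = get_per_tensor_qparams(obs)
#   xq = torch.quantize_per_tensor(x, scale, zero_point, dtype)
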
def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[str, Union[Callable, str], Dict[str, Any]]]:
    ''' Given an activation_post_process module,
    return the node_type (e.g. call_function), the quantize op (e.g. torch.quantize_per_tensor)
    and a dictionary of qparams extracted from the module
    '''
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]
    compute_dtype = None
    if hasattr(activation_post_process, "compute_dtype"):
        compute_dtype = activation_post_process.compute_dtype  # type: ignore[attr-defined]
    quantize_op : Optional[Union[Callable, str]] = None
    if dtype in [torch.quint8, torch.qint8]:
        node_type = "call_function"
        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined]
        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined]
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype}
            quantize_op = torch.quantize_per_channel
        else:
            scale = float(scale)
            zero_point = int(zero_point)
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
            quantize_op = torch.quantize_per_tensor
    elif dtype == torch.float16:
        node_type = "call_method"
        quantize_op = "to"
        qparams = {"_dtype_": dtype}
    elif dtype == torch.float32 and compute_dtype in [torch.quint8, torch.qint8, torch.float16]:
        # dynamic quantization
        node_type = "call_function"
        quantize_op = torch.quantize_per_tensor_dynamic
        # TODO: get reduce range from observer
        # reduce_range = activation_post_process.reduce_range
        reduce_range = torch.backends.quantized.engine == "fbgemm"
        qparams = {"_dtype_": compute_dtype, "_reduce_range_": reduce_range}
    else:
        warnings.warn(f"Unsupported activation_post_process in get_quantize_node_info: {activation_post_process}")
        return None
    return node_type, quantize_op, qparams

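# For example, a calibrated per-tensor quint8 observer yields
#   ("call_function", torch.quantize_per_tensor,
#    {"_scale_": <float>, "_zero_point_": <int>, "_dtype_": torch.quint8})
# while a float16 observer yields ("call_method", "to", {"_dtype_": torch.float16}).
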
def quantize_node(
        in_node: Node,
        obs_module: torch.nn.Module,
        obs_node: Node,
        modules: Dict[str, torch.nn.Module],
        quantized_graph: Graph,
        node_name_to_scope: Dict[str, Tuple[str, type]],
        is_input: bool,
        output_prefix: str = "_output") -> Node:
    ''' Add quantization nodes (e.g. quantize_per_tensor/per_channel) for given node to graph
    with the qparams calculated from activation_post_process (obs_module).
    The observer node (obs_node) is used to find the FQN of the user of act_post_process.
    e.g. Given input `node` in `node = self.conv(x)`, insert node:
    `quantized_node = torch.quantize_per_tensor(x, self._scale_0, self._zero_point_0, self._dtype_0)`
    where self._scale_0, self._zero_point_0 and self._dtype_0 are
    calculated from `obs_module`
    '''
    # Find the first use of the observer node, we use this to get the scope of the module.
    if is_input:
        # if the quantize function is at the input of op, then we find the first user of the observer_node
        # to get the path. If a linear call_function is in the user list, we return the first instance
        # of linear node to get the FQN.
        users = list(obs_node.users)
        first_linear_use_or_first_use = users[0] if users else None
        linear_node = None
        for n in users:
            if n.op == "call_function" and n.target == torch.nn.functional.linear:
                linear_node = n
                break
        if linear_node:
            first_linear_use_or_first_use = linear_node
        prefix = "_input"
    else:
        # if the quantize function is at the output of the op, we use the observer input node to get the path
        first_linear_use_or_first_use = in_node
        prefix = output_prefix

    if first_linear_use_or_first_use and first_linear_use_or_first_use.name in node_name_to_scope:
        module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
    else:
        # TODO: it's not used, so actually we can skip quantization
        # but this requires changing return type of quantize_node
        # we can fix it later if needed
        module_path = ""
    root_module = modules['']
    graph = quantized_graph
    maybe_quantize_node_info = get_quantize_node_info(obs_module)
    assert maybe_quantize_node_info is not None, \
        f"Expecting quantize node info not to be None, observer: {obs_module}"
    node_type, quantize_op, qparams = maybe_quantize_node_info
    inputs = [in_node]

    for key, value in qparams.items():
        if key in ['_scale_', '_zero_point_']:
            # For scale and zero_point values we register them as buffers in the root module.
            qparam_node = create_getattr_from_value(root_module, graph, module_path + prefix + key, value)
            inputs.append(qparam_node)
        else:
            # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
            inputs.append(value)
    return graph.create_node(node_type, quantize_op, tuple(inputs), {})

def get_custom_module_class_keys(custom_module_mapping: Dict[QuantType, Dict[Type, Type]]) -> List[Any]:
    r""" Get all the unique custom module keys in the custom config dict
    e.g.
    Input:
    {
        QuantType.STATIC: {
            CustomModule1: ObservedCustomModule
        },
        QuantType.DYNAMIC: {
            CustomModule2: DynamicObservedCustomModule
        },
        QuantType.WEIGHT_ONLY: {
            CustomModule3: WeightOnlyObservedCustomModule
        },
    }

    Output:
    # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
    [CustomModule1, CustomModule2, CustomModule3]
    """
    # using set to dedup
    float_custom_module_classes : Set[Any] = set()
    for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]:
        quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
        quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
        float_custom_module_classes |= quant_mode_custom_module_classes
    return list(float_custom_module_classes)

def get_linear_prepack_op_for_dtype(dtype):
    if dtype == torch.float16:
        return torch.ops.quantized.linear_prepack_fp16
    elif dtype == torch.qint8:
        return torch.ops.quantized.linear_prepack
    else:
        raise Exception("can't get linear prepack op for dtype:", dtype)

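# e.g. get_linear_prepack_op_for_dtype(torch.qint8) returns
# torch.ops.quantized.linear_prepack
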
def get_qconv_prepack_op(conv_op: Callable) -> Callable:
    prepack_ops = {
        torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
        torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
        torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack
    }
    prepack_op = prepack_ops.get(conv_op, None)
    assert prepack_op, "Didn't find prepack op for {}".format(conv_op)
    return prepack_op

def get_qconv_op(conv_op: Callable, has_relu: bool) -> Callable:
    qconv_op = {
        # has relu
        True: {
            torch.nn.functional.conv1d: torch.ops.quantized.conv1d_relu,
            torch.nn.functional.conv2d: torch.ops.quantized.conv2d_relu,
            torch.nn.functional.conv3d: torch.ops.quantized.conv3d_relu
        },
        False: {
            torch.nn.functional.conv1d: torch.ops.quantized.conv1d,
            torch.nn.functional.conv2d: torch.ops.quantized.conv2d,
            torch.nn.functional.conv3d: torch.ops.quantized.conv3d
        }
    }
    qconv = qconv_op[has_relu].get(conv_op)
    assert qconv, "Can't find corresponding quantized conv op for {} {}".format(conv_op, has_relu)
    return qconv

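# e.g. get_qconv_op(torch.nn.functional.conv2d, has_relu=True) returns
# torch.ops.quantized.conv2d_relu
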
# Returns a function that can get a new attribute name for module with given
# prefix, for example,
# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
# >> new_name = get_new_observer_name(module)
# new_name will be an unused attribute name on module, e.g. `_observer_1`
def get_new_attr_name_with_prefix(prefix: str) -> Callable:
    prefix = prefix.replace(".", "_")

    def get_new_attr_name(module: torch.nn.Module):
        def get_attr_name(i: int):
            return prefix + str(i)
        i = 0
        attr_name = get_attr_name(i)
        while hasattr(module, attr_name):
            i += 1
            attr_name = get_attr_name(i)
        return attr_name
    return get_new_attr_name

def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
    r''' Starting from a target node, trace back until we hit an input or
    getattr node. This is used to extract the chain of operators
    starting from getattr to the target node, for example
    def forward(self, x):
        observed = self.observer(self.weight)
        return F.linear(x, observed)
    collect_producer_nodes(observed) will either return a list of nodes that
    produces the observed node or None if we can't extract a self-contained
    graph without free variables (inputs of the forward function).
    '''
    nodes = [node]
    frontier = [node]
    while frontier:
        node = frontier.pop()
        all_args = list(node.args) + list(node.kwargs.values())
        for arg in all_args:
            if not isinstance(arg, Node):
                continue
            if arg.op == 'placeholder':
                # hit input, can't fold in this case
                return None
            nodes.append(arg)
            if not (arg.op == 'call_function' and arg.target == getattr):
                frontier.append(arg)
    return nodes

def graph_module_from_producer_nodes(
        root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
    r''' Construct a graph module from extracted producer nodes
    from `collect_producer_nodes` function
    Args:
      root: the root module for the original graph
      producer_nodes: a list of nodes we use to construct the graph
    Return:
      A graph module constructed from the producer nodes
    '''
    assert len(producer_nodes) > 0, 'list of producer nodes can not be empty'
    # since we traced back from node to getattr
    producer_nodes.reverse()
    graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node])
    for producer_node in producer_nodes:
        env[producer_node] = graph.node_copy(producer_node, load_arg)
    graph.output(load_arg(producer_nodes[-1]))
    graph_module = GraphModule(root, graph)
    return graph_module

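# Example (sketch; `root_gm` and `observed_weight_node` are hypothetical names
# for a traced root GraphModule and an observed-weight node inside it):
#   producer_nodes = collect_producer_nodes(observed_weight_node)
#   if producer_nodes is not None:
#       weight_gm = graph_module_from_producer_nodes(root_gm, producer_nodes)
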
def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
    """
    Returns the unique device for a module, or None if no device is found.
    Throws an error if multiple devices are detected.
    """
    devices = {p.device for p in module.parameters()} | \
        {p.device for p in module.buffers()}
    assert len(devices) <= 1, (
        "prepare only works with cpu or single-device CUDA modules, "
        "but got devices {}".format(devices)
    )
    device = next(iter(devices)) if len(devices) > 0 else None
    return device

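# e.g. returns torch.device("cuda", 0) for a module whose parameters and
# buffers all live on cuda:0, and None for a module with no parameters or buffers
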
def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
    """
    Given a value of any type, creates a getattr node corresponding to the value and
    registers the value as a buffer to the module.
    """
    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
    attr_name = get_new_attr_name(module)
    device = assert_and_get_unique_device(module)
    module.register_buffer(attr_name, torch.tensor(value, device=device))
    # Create get_attr with value
    attr_node = graph.create_node("get_attr", attr_name)
    return attr_node

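# Example (sketch; `root` and `graph` are hypothetical names for the root
# module and the graph being built):
#   scale_node = create_getattr_from_value(root, graph, "_scale_", 0.5)
# registers a buffer such as `_scale_0` on `root` and returns a get_attr node for it.
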
def create_qparam_nodes(
        node_name: str,
        scale: Any,
        zero_point: Any,
        modules: Dict[str, torch.nn.Module],
        quantized_graph: Graph,
        node_name_to_scope: Dict[str, Tuple[str, type]]
) -> Tuple[Node, Node]:
    """
    Create getattr nodes in the quantized graph for scale and zero point values.
    The nodes are registered with the root_module of the model.
    """
    root_module = modules['']
    module_path, _ = node_name_to_scope[node_name]
    scale_node = create_getattr_from_value(root_module, quantized_graph, (module_path + "_scale_"), scale)
    zero_point_node = create_getattr_from_value(root_module, quantized_graph, (module_path + "_zero_point_"), zero_point)
    return (scale_node, zero_point_node)

def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]) -> bool:
    """
    If we know for sure that all of this node's args have no
    tensors (are primitives), return True. If we either
    find a tensor or are not sure, return False. Note: this
    function is not exact.
    """
    if cache and node in cache:
        return cache[node]

    result = False  # will be overwritten
    if not isinstance(node, Node):
        result = True
    elif node.op == 'placeholder':
        result = False
    elif node.op == 'call_module':
        assert isinstance(node.target, str)
        if is_activation_post_process(modules[node.target]):
            result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
        # for other modules `result` stays False, since they may produce tensors
    elif node.op == 'call_function' and node.target is operator.getitem:
        result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
    elif node.op == 'get_attr':
        result = False
    elif node.target is getattr and node.args[1] in ['ndim', 'shape']:
        # x1 = x0.ndim
        result = True
    elif node.op == 'call_method' and node.target == 'size':
        # x1 = x0.size(0)
        result = True
    else:
        found_one_tensor = False
        for arg in node.args:
            if isinstance(arg, list):
                for list_el in arg:
                    if isinstance(list_el, Node):
                        this_list_el_args_have_no_tensors = \
                            all_node_args_have_no_tensors(list_el, modules, cache)
                        found_one_tensor = found_one_tensor or \
                            (not this_list_el_args_have_no_tensors)
                        # If found_one_tensor is True, there is no point in
                        # recursing further as the end result will always
                        # be True.
                        # TODO(future PR): remove this entire function and
                        # change to dtype inference without recursion.
                        if found_one_tensor:
                            result = not found_one_tensor
                            if cache:
                                cache[node] = result
                            return result
            elif isinstance(arg, int):
                pass
            else:
                if isinstance(arg, Node):
                    this_arg_args_have_no_tensors = all_node_args_have_no_tensors(arg, modules, cache)
                    found_one_tensor = found_one_tensor or \
                        (not this_arg_args_have_no_tensors)
                    # If found_one_tensor is True, there is no point in
                    # recursing further as the end result will always
                    # be True.
                    # TODO(future PR): remove this entire function and
                    # change to dtype inference without recursion.
                    if found_one_tensor:
                        result = not found_one_tensor
                        if cache:
                            cache[node] = result
                        return result
                else:
                    found_one_tensor = True
        result = not found_one_tensor
    if cache:
        cache[node] = result
    return result

def all_node_args_except_first(node: Node) -> List[int]:
    """
    Returns all node arg indices after first
    """
    return list(range(1, len(node.args)))

def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
    """
    Constructs a function that takes a node as arg and returns the arg_indices
    that are valid for node.args
    """
    def arg_indices_func(node: Node) -> List[int]:
        return [i for i in arg_indices if i < len(node.args)]
    return arg_indices_func

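# e.g. return_arg_list([1, 3]) builds a function that returns [1, 3] for a node
# with four or more args, but only [1] for a node with just two args
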
NodeInfo = namedtuple("NodeInfo", "op target")

# this dict identifies which indices of a node are non-tensors
# so that they can be propagated correctly since inserting observers
# for them would cause errors

NON_OBSERVABLE_ARG_DICT: Dict[NodeInfo, Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]] = {
    NodeInfo("call_method", "masked_fill") : {
        torch.bool: return_arg_list([1]),
        float: return_arg_list([2])
    },
    NodeInfo("call_method", "permute") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "repeat") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "reshape") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "size") : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", "transpose") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", torch.transpose) : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "unsqueeze") : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", "unsqueeze_") : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", torch.unsqueeze) : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", "view") : {
        int: all_node_args_except_first
    },
}

EMPTY_ARG_DICT: Dict[Union[type, torch.dtype], Callable[[Node], List[int]]] = {}

def get_non_observable_arg_indexes_and_types(node: Node) -> Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]:
    """
    Returns a dict whose keys are non-float-tensor types and whose values are
    functions that, given the node, retrieve the list of arg indices of that type
    """
    info = NodeInfo(node.op, node.target)

    return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT)

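# e.g. for an `x.permute(0, 2, 1)` call_method node this returns
# {int: all_node_args_except_first}, marking arg indices 1 and up as
# non-observable ints
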
def node_return_type_is_int(node: Node) -> bool:
    """
    Returns true if this node results in an integer, even if some of the args
    are Tensors.
    """
    return node.op == 'call_method' and node.target == 'size'

def is_get_tensor_info_node(node: Node) -> bool:
    """ Returns True if this node is a node that takes a Tensor as input and outputs some
    meta information about the Tensor, e.g. shape, size etc.
    """
    result: bool = \
        node.op == "call_function" and node.target == getattr and node.args[1] == "shape"  # type: ignore[assignment]
    return result

def maybe_get_next_module(
    node: Node,
    modules: Dict[str, nn.Module],
    target_module_type: Optional[Type[nn.Module]] = None,
    target_functional_type: Any = None,
) -> Optional[Node]:
    """ Gets the next user of `node` that matches the given module type or
    functional type, if one exists

    Args:
        node: The node whose users we want to look at
        target_module_type: Module type that we want to check
        target_functional_type: Functional type that we want to check
    """

    for user, _ in node.users.items():
        if user.op == 'call_module' and target_module_type is not None and \
                isinstance(modules[str(user.target)], target_module_type):
            return user
        elif (user.op == 'call_function' and target_functional_type is not None and
                user.target == target_functional_type):
            return user

    return None

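# Example (sketch; `conv_node` is a hypothetical conv node in the graph):
#   bn_node = maybe_get_next_module(conv_node, modules,
#                                   target_module_type=nn.BatchNorm2d)
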
def create_node_from_old_node_preserve_meta(
    quantized_graph: Graph,
    create_node_args: Tuple[Any, ...],
    old_node: Node,
) -> Node:
    """
    Creates `new_node` and copies the necessary metadata to it from `old_node`.
    """
    new_node = quantized_graph.create_node(*create_node_args)
    new_node.stack_trace = old_node.stack_trace
    return new_node

def get_skipped_module_name_and_classes(
        prepare_custom_config: PrepareCustomConfig,
        is_standalone_module: bool) -> Tuple[List[str], List[Type[Any]]]:
    skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names)
    skipped_module_classes = copy.copy(prepare_custom_config.non_traceable_module_classes)
    if not is_standalone_module:
        # standalone module and custom module config are applied in top level module
        skipped_module_names += list(prepare_custom_config.standalone_module_names.keys())
        skipped_module_classes += list(prepare_custom_config.standalone_module_classes.keys())
        skipped_module_classes += get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)

    return skipped_module_names, skipped_module_classes