mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61360 By default, NS graph matching matches from the end of the graph to the start. This PR reverses the returned results so that the outputs of the NS APIs are in the order of execution, making it easier to analyze. Test Plan: ``` python test/test_quantization.py TestFXGraphMatcher.test_results_order ``` Imported from OSS Reviewed By: hx89 Differential Revision: D29600348 fbshipit-source-id: c9fa4a3748db27c1788eebf803f35221e6fc8701
461 lines
19 KiB
Python
461 lines
19 KiB
Python
import collections
|
|
import enum
|
|
|
|
import torch
|
|
toq = torch.ops.quantized
|
|
|
|
from torch.fx import GraphModule
|
|
from torch.fx.graph import Graph, Node
|
|
|
|
from .utils import getattr_from_fqn
|
|
from .ns_types import NSSubgraph, NSNodeTargetType
|
|
from .mappings import (
|
|
get_base_name_to_sets_of_related_ops,
|
|
get_unmatchable_types_map,
|
|
)
|
|
from .pattern_utils import (
|
|
get_type_a_related_to_b,
|
|
get_reversed_fusions,
|
|
end_node_matches_reversed_fusion,
|
|
)
|
|
from torch.quantization import (
|
|
ObserverBase,
|
|
FakeQuantizeBase,
|
|
)
|
|
|
|
from typing import Dict, Tuple, List, Optional, Set, Any
|
|
|
|
def _get_output_nodes(g: Graph) -> List[Node]:
|
|
return [n for n in g.nodes if n.op == 'output']
|
|
|
|
class _NSGraphMatchableSubgraphsIterator:
|
|
"""
|
|
Iterates through the graph of gm, starting with the output nodes
|
|
and continuing backwards.
|
|
1. Returns matchable subgraphs, in order. A subgraph is defined by
|
|
(start_node, end_node).
|
|
2. Skips over non-matchable subgraphs
|
|
"""
|
|
def __init__(
|
|
self,
|
|
gm: GraphModule,
|
|
non_matchable_functions: Set[NSNodeTargetType],
|
|
non_matchable_modules: Set[NSNodeTargetType],
|
|
non_matchable_methods: Set[NSNodeTargetType],
|
|
):
|
|
self.gm: GraphModule = gm
|
|
self.non_matchable_functions: Set[NSNodeTargetType] = non_matchable_functions
|
|
self.non_matchable_modules: Set[NSNodeTargetType] = non_matchable_modules
|
|
self.non_matchable_methods: Set[NSNodeTargetType] = non_matchable_methods
|
|
self.seen_nodes: Set[Node] = set()
|
|
self.stack: List[Node] = []
|
|
for start_node in _get_output_nodes(self.gm.graph):
|
|
self.stack.append(start_node)
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self) -> NSSubgraph:
|
|
"""
|
|
Returns the next matchable subgraph.
|
|
"""
|
|
while len(self.stack) > 0:
|
|
cur_end_node = self.stack.pop()
|
|
if cur_end_node in self.seen_nodes:
|
|
continue
|
|
|
|
# for subgraphs which are single nodes, start_node == end_node
|
|
# for subgraphs with more than one node, start node != end_node
|
|
cur_start_node = cur_end_node
|
|
# Subgraphs like linear-relu have the base node as the start node.
|
|
# Subgraphs like dequantize-linear-relu-to(torch.float16) have the
|
|
# base node as the second node.
|
|
# The cur_base_op_node var will move to the actual node during
|
|
# the fusion matching later in this code block.
|
|
cur_base_op_node = cur_end_node
|
|
|
|
# Check for potential fusions. For now, we are greedy
|
|
# and always skip all non-base nodes of a fusion. For example,
|
|
# if we match linear-relu backwards, we will always skip the
|
|
# relu node and attempt to match the linear node. This can
|
|
# be made configurable later if needed.
|
|
for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
|
|
is_match = end_node_matches_reversed_fusion(
|
|
cur_end_node, _reverse_fusion_ops, self.gm)
|
|
if is_match:
|
|
# navigate to the base node
|
|
for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
|
|
self.seen_nodes.add(cur_start_node)
|
|
# for now, assume that there are no other nodes
|
|
# which need to be added to the stack
|
|
cur_start_node = cur_start_node.args[0] # type: ignore[assignment]
|
|
# if the base op index matches the current node, set it
|
|
rev_base_op_idx = \
|
|
len(_reverse_fusion_ops) - 2 - base_op_idx
|
|
if rev_fusion_idx == rev_base_op_idx:
|
|
cur_base_op_node = cur_start_node
|
|
break
|
|
|
|
self.seen_nodes.add(cur_start_node)
|
|
# add args of previous nodes to stack
|
|
for arg in cur_start_node.all_input_nodes:
|
|
self._recursively_add_node_arg_to_stack(arg)
|
|
|
|
# skip unmatchable nodes
|
|
# note: this check is done on the start_node, i.e.
|
|
# if we are matching linear-relu in reverse, this would do the matchable
|
|
# check on the linear
|
|
if not self._is_matchable(cur_base_op_node):
|
|
continue
|
|
|
|
# If an observer or a fake_quant was not matched as a part of
|
|
# a pattern of multiple nodes, ignore it. One case where this is
|
|
# relevant is an observer on a graph input, which was added because
|
|
# it is necessary for the next node.
|
|
if cur_end_node.op == 'call_module' and cur_start_node is cur_end_node:
|
|
maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target) # type: ignore[arg-type]
|
|
if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
|
|
continue
|
|
|
|
return NSSubgraph(
|
|
start_node=cur_start_node, end_node=cur_end_node,
|
|
base_op_node=cur_base_op_node)
|
|
|
|
raise StopIteration
|
|
|
|
def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
|
|
"""
|
|
Adds all of the nodes in this arg to the stack, properly navigating
|
|
through list, dicts and tuples.
|
|
"""
|
|
if isinstance(arg, Node):
|
|
self.stack.append(arg)
|
|
elif isinstance(arg, torch.fx.immutable_collections.immutable_list) or type(arg) is tuple:
|
|
for inner_arg in arg:
|
|
self._recursively_add_node_arg_to_stack(inner_arg)
|
|
elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
|
|
for key, value in arg.items():
|
|
self._recursively_add_node_arg_to_stack(value)
|
|
|
|
def _is_matchable(self, node: Node) -> bool:
|
|
if node.op == 'call_function':
|
|
return not (node.target in self.non_matchable_functions)
|
|
elif node.op == 'call_module':
|
|
assert isinstance(node.target, str)
|
|
target_mod = getattr_from_fqn(self.gm, node.target)
|
|
return not \
|
|
any(isinstance(target_mod, t) # type: ignore[arg-type]
|
|
for t in self.non_matchable_modules)
|
|
elif node.op == 'call_method':
|
|
return not (node.target in self.non_matchable_methods)
|
|
else:
|
|
return False
|
|
|
|
class GraphMatchingException(Exception):
|
|
"""
|
|
Exception raised when two graphs cannot be matched.
|
|
"""
|
|
pass
|
|
|
|
class SubgraphTypeRelationship(enum.Enum):
|
|
# same type, known
|
|
# example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d
|
|
EQUAL = enum.auto()
|
|
# same type, but the type is not known to Numerical Suite
|
|
# (user defined type, etc).
|
|
EQUAL_BUT_UKNOWN = enum.auto()
|
|
# known, same subgraph_relationship set, but not the same type
|
|
# example: F.linear and toq.linear
|
|
RELATED_BUT_NOT_EQUAL = enum.auto()
|
|
# not related
|
|
NOT_RELATED = enum.auto()
|
|
|
|
def _get_subgraph_relationship_type(
|
|
subgraph_a: NSSubgraph,
|
|
subgraph_b: NSSubgraph,
|
|
gm_a: GraphModule,
|
|
gm_b: GraphModule,
|
|
type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
|
|
) -> SubgraphTypeRelationship:
|
|
node_a = subgraph_a.base_op_node
|
|
node_b = subgraph_b.base_op_node
|
|
|
|
# TODO(next): make this code handle matching by what is before the base op
|
|
if node_a.op != node_b.op:
|
|
if not (
|
|
node_a.op in ('call_function', 'call_method') and
|
|
node_b.op in ('call_function', 'call_method')
|
|
):
|
|
return SubgraphTypeRelationship.NOT_RELATED
|
|
|
|
if node_a.op in ('call_function', 'call_method'):
|
|
key = (node_a.target, node_b.target)
|
|
|
|
if key not in type_a_related_to_b:
|
|
if node_a.target == node_b.target:
|
|
return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
|
|
else:
|
|
return SubgraphTypeRelationship.NOT_RELATED
|
|
# after this point, we are dealing with known types
|
|
|
|
if node_a.target == node_b.target:
|
|
node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
|
|
node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
|
|
if node_a_has_prev and (not node_b_has_prev):
|
|
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
|
|
elif (not node_a_has_prev) and node_b_has_prev:
|
|
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
|
|
elif (not node_a_has_prev) and (not node_b_has_prev):
|
|
return SubgraphTypeRelationship.EQUAL
|
|
else:
|
|
# TODO(future PR): check for matches start_op_node and base_op_node
|
|
return SubgraphTypeRelationship.EQUAL
|
|
|
|
if key in type_a_related_to_b:
|
|
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
|
|
else:
|
|
return SubgraphTypeRelationship.NOT_RELATED
|
|
elif node_a.op == 'call_module':
|
|
assert (subgraph_a.base_op_node == subgraph_a.start_node and
|
|
subgraph_b.base_op_node == subgraph_b.start_node), \
|
|
"Matching call_module patterns where base_op_node != start_node is not supported yet"
|
|
# for call_module, we need to look up the modules to do the type check
|
|
assert isinstance(node_a.target, str)
|
|
mod_a = getattr_from_fqn(gm_a, node_a.target)
|
|
assert isinstance(node_b.target, str)
|
|
mod_b = getattr_from_fqn(gm_b, node_b.target)
|
|
|
|
key = (type(mod_a), type(mod_b))
|
|
|
|
if key not in type_a_related_to_b:
|
|
if type(mod_a) == type(mod_b):
|
|
return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
|
|
else:
|
|
return SubgraphTypeRelationship.NOT_RELATED
|
|
elif type(mod_a) == type(mod_b):
|
|
return SubgraphTypeRelationship.EQUAL
|
|
else:
|
|
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
|
|
|
|
return SubgraphTypeRelationship.NOT_RELATED
|
|
|
|
def _get_name_for_subgraph(
|
|
subgraph_a: NSSubgraph,
|
|
gm_a: GraphModule,
|
|
base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
|
|
existing_names: Set[str],
|
|
) -> str:
|
|
"""
|
|
Returns a unique name for a subgraph. This name is based on two things:
|
|
1. the name of the set containing the underlying type of the base op in the
|
|
subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op)
|
|
2. the number of previous subgraphs with related underlying type of the base op
|
|
|
|
For example, in the graph
|
|
|
|
linear0 -> relu0 -> linear1 -> relu1
|
|
|
|
The subgraphs are (linear0, relu0) and (linear1, relu1). If we iterate
|
|
from the output node backwards, the name given to (linear1, relu1) will be
|
|
`base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0)
|
|
will be `base_op_torch.nn.functional.linear_1`.
|
|
|
|
Why are we not just using the node name? Answer: because of two requirements:
|
|
A. fusions must be supported
|
|
B. some Numeric Suite APIs can be called without having all of the models in memory
|
|
|
|
For example, let's say we need to match nodes of
|
|
|
|
(1) ... -> linear0 -> relu0 -> ...
|
|
|
|
And
|
|
|
|
(2) ... -> linear_relu0 -> ...
|
|
|
|
Without being able to inspect them together. With the current naming scheme, if
|
|
we iterate through both of these graphs in the same order, and assuming the rest
|
|
of the graphs match, both of these subgraphs will get the same name without
|
|
(1) and (2) knowing anything about each other.
|
|
"""
|
|
target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
|
|
target_base_type = None
|
|
for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
|
|
if target_type in sets_of_related_ops:
|
|
target_base_type = base_name
|
|
target_base_name = 'base_op_' + str(target_base_type)
|
|
counter = 0
|
|
proposed_name = target_base_name + '_' + str(counter)
|
|
while proposed_name in existing_names:
|
|
counter += 1
|
|
proposed_name = target_base_name + '_' + str(counter)
|
|
existing_names.add(proposed_name)
|
|
return proposed_name
|
|
|
|
def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
|
|
if node.op in ('call_function', 'call_method'):
|
|
return node.target
|
|
elif node.op == 'call_module':
|
|
assert isinstance(node.target, str)
|
|
mod = getattr_from_fqn(gm, node.target)
|
|
return type(mod)
|
|
return None
|
|
|
|
def get_matching_subgraph_pairs(
|
|
gm_a: GraphModule,
|
|
gm_b: GraphModule,
|
|
base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
|
|
unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
|
|
) -> Dict[str, Tuple[NSSubgraph, NSSubgraph]]:
|
|
"""
|
|
Matches matchable subgraphs of graph_a to graph_b.
|
|
|
|
For a node, "matchable" is defined as a node which is not an observer,
|
|
fake_quants, quant or dequant.
|
|
|
|
A subgraph can contain one or more nodes. A subgraph is matchable if
|
|
at least one node inside of it is matchable. Currently, all nodes in
|
|
a subgraph must be matchable (because we assume no observers will be
|
|
inserted in the middle of a fusion).
|
|
|
|
A subgraph is defined by (start_node, end_node). We assume that only
|
|
start_node and end_node are linked with the surrounding graph, all other
|
|
nodes in a subgraph are self-contained.
|
|
|
|
A pair of nodes is "related" if both nodes represent the same mathematical
|
|
operation across different quantization flavors. For example,
|
|
`F.linear` and `torch.ops.quantized.linear` are related, and
|
|
`F.linear` and `torch.nn.Conv` are not related.
|
|
|
|
For each matchable pair of nodes node_a and node_b, they will match
|
|
if node_a and node_b are related.
|
|
|
|
For graphs A and B, they will match iff:
|
|
1. the number of matchable subgraphs in A and B is equivalent
|
|
2. when iterating through the matchable subgraphs of A and B in the same order, each
|
|
corresponding pair of base nodes is related.
|
|
|
|
This enables us to find the corresponding subgraphs between
|
|
graphs of related models. For example, if we had two graphs such as:
|
|
|
|
graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
|
|
w -/
|
|
b -/
|
|
|
|
graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
|
|
packed_params_0 -/
|
|
|
|
This function will return the following result:
|
|
{
|
|
'conv_0': ( # the name of the node in graph_b
|
|
(conv_0, conv_0), # (start_node_a, end_node_a)
|
|
(qconv_0, qconv_0), # (start_node_b, end_node_b)
|
|
),
|
|
}
|
|
|
|
Or, if we have a fusion pattern,
|
|
|
|
graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
|
|
w -/
|
|
b -/
|
|
|
|
graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
|
|
packed_params_0 -/
|
|
|
|
This function will return the following result:
|
|
{
|
|
'linear_relu_0': ( # the name of the node in graph_b
|
|
(linear_0, relu_0), # (start_node_a, end_node_a)
|
|
(linear_relu_0, linear_relu_0), # (start_node_b, end_node_b)
|
|
),
|
|
}
|
|
"""
|
|
if unmatchable_types_map is None:
|
|
unmatchable_types_map = get_unmatchable_types_map()
|
|
non_matchable_functions = unmatchable_types_map['funs_unmatchable']
|
|
non_matchable_modules = unmatchable_types_map['mods_unmatchable']
|
|
non_matchable_methods = unmatchable_types_map['meths_unmatchable']
|
|
|
|
graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
|
|
gm_a, non_matchable_functions, non_matchable_modules,
|
|
non_matchable_methods)
|
|
graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
|
|
gm_b, non_matchable_functions, non_matchable_modules,
|
|
non_matchable_methods)
|
|
results = collections.OrderedDict()
|
|
if base_name_to_sets_of_related_ops is None:
|
|
base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
|
|
type_a_related_to_b = \
|
|
get_type_a_related_to_b(base_name_to_sets_of_related_ops)
|
|
|
|
existing_names_a: Set[str] = set()
|
|
existing_names_b: Set[str] = set()
|
|
|
|
while True:
|
|
# fetch the next subgraphs from a and b
|
|
cur_subgraph_a, cur_subgraph_b = None, None
|
|
try:
|
|
cur_subgraph_a = next(graph_a_iterator)
|
|
except StopIteration:
|
|
pass
|
|
try:
|
|
cur_subgraph_b = next(graph_b_iterator)
|
|
except StopIteration:
|
|
pass
|
|
|
|
# look up types of a and b for useful error messages
|
|
type_start_a, type_start_b = None, None
|
|
if cur_subgraph_a is not None:
|
|
type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a)
|
|
if cur_subgraph_b is not None:
|
|
type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b)
|
|
|
|
# check for results and determine what to do next
|
|
if cur_subgraph_a is not None and cur_subgraph_b is not None:
|
|
# both nodes were fetched, check for subgraph_relationship
|
|
# note: subgraph_relationship is checked on the start node, i.e.
|
|
# if a linear-relu pattern is checked, we would check for subgraph_relationship
|
|
# of the linear
|
|
subgraph_relationship = _get_subgraph_relationship_type(
|
|
cur_subgraph_a, cur_subgraph_b,
|
|
gm_a, gm_b, type_a_related_to_b)
|
|
if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
|
|
msg = f"""
|
|
The subgraphs
|
|
({cur_subgraph_a}, {type_start_a}) and
|
|
({cur_subgraph_b}, {type_start_b})
|
|
are not related. Please ensure that the two models you pass in have the same number
|
|
of subgraphs, and each pair of subgraphs is related to each other."""
|
|
raise GraphMatchingException(msg)
|
|
elif subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN:
|
|
# skip matching but unknown types
|
|
continue
|
|
key_name_a = _get_name_for_subgraph(
|
|
cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops,
|
|
existing_names_a)
|
|
key_name_b = _get_name_for_subgraph(
|
|
cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops,
|
|
existing_names_b)
|
|
assert key_name_a == key_name_b, \
|
|
f"Subgraph names {key_name_a} and {key_name_b} do not match"
|
|
results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
|
|
continue
|
|
elif cur_subgraph_a is None and cur_subgraph_b is None:
|
|
# we reached the end of both graphs
|
|
break
|
|
else:
|
|
# only one node was fetched, no match possible, throw error
|
|
msg = f"""
|
|
Attempting to match
|
|
({cur_subgraph_a}, {type_start_a}) and
|
|
({cur_subgraph_b}, {type_start_b}),
|
|
one of which is empty. Please ensure that the two models you pass in have the same number
|
|
of subgraphs."""
|
|
raise GraphMatchingException(msg)
|
|
|
|
# The subgraph pairs are originally created by traversing the two graphs
|
|
# from the outputs to the inputs. Reverse the results to return the
|
|
# subgraphs in their order of execution.
|
|
results = collections.OrderedDict(reversed(list(results.items())))
|
|
|
|
return results
|