ns for fx: add partial support for subgraphs with base_op_node (#54254)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54254

In fp16 emulation, we now have patterns such as

```
... -> dequantize -> linear -> relu -> to(torch.float16) -> ...
```

This PR adds support for
* specifying a subgraph's "base_op_node", which is the node whose op is
used to match the subgraph against related subgraphs. In the example
above, "base_op_node" would be the linear node, which is the second
node of the matched pattern (see the sketch after this list).
* matching these fusion patterns and properly setting "base_op_node"
based on pattern and index
* using "base_op_node" instead of "start_node" throughout the NS
codebase wherever the intent is to match subgraphs or create names
for subgraphs.
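
Below is a minimal sketch of the example pattern as an FX graph, annotated
with the roles described above. It is illustrative only (the module is
hypothetical and not part of this PR); tracing is symbolic, so the
`dequantize` call does not need a real quantized tensor:

```
import torch
import torch.fx
import torch.nn.functional as F

class LinearReluFp16Pattern(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))
        self.b = torch.nn.Parameter(torch.zeros(4))

    def forward(self, x):
        x = x.dequantize()                # start_node of the matched subgraph
        x = F.linear(x, self.w, self.b)   # base_op_node: the op used for matching
        x = F.relu(x)
        x = x.to(torch.float16)           # end_node of the matched subgraph
        return x

gm = torch.fx.symbolic_trace(LinearReluFp16Pattern())
print(gm.graph)  # dequantize -> linear -> relu -> to(torch.float16)
```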

At the end of this PR, matching unshadowed activations with an example
fp16 emulation pattern works e2e.

I'm saving the following work for future PRs (soon), mostly to keep
PR size manageable:
* adding weight matching (will require some changes to the function which
extracts weights)
* adding shadowed activation matching (will require some changes to
shadow copying)
* adding input logging for these patterns (will likely require some changes as well)

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_linear_fp16
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D27158199

fbshipit-source-id: 49fc445395452fda62e3c7a243544190f9af691c
Vasiliy Kuznetsov 2021-03-25 22:27:30 -07:00 committed by Facebook GitHub Bot
parent 454832e5fa
commit 182d8c375c
4 changed files with 169 additions and 58 deletions


@@ -539,6 +539,31 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
results_len=2,
should_log_inputs=True)
@skipIfNoFBGEMM
def test_linear_fp16(self):
class M(nn.Module):
def __init__(self):
super().__init__()
self.w1 = nn.Parameter(torch.Tensor(4, 4))
self.b1 = nn.Parameter(torch.zeros(4))
torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
def forward(self, x):
x = F.linear(x, self.w1, self.b1)
x = F.relu(x)
return x
qconfig_dict = {'': torch.quantization.float16_static_qconfig}
m = M().eval()
expected_occurrence = {
ns.call_module(OutputLogger): 1,
}
self._test_match_activations(
m, (torch.randn(4, 4),),
prepared_expected_node_occurrence=expected_occurrence,
results_len=1,
qconfig_dict=qconfig_dict)
class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
"""


@@ -17,7 +17,7 @@ from torch.fx.graph import Graph, Node
from .utils import getattr_from_fqn
from .ns_types import NSSubgraph
from typing import Dict, Tuple, List, Optional, Set, Callable, Any
from typing import Dict, Tuple, List, Optional, Set, Callable, Any, Union
def _get_output_nodes(g: Graph) -> List[Node]:
return [n for n in g.nodes if n.op == 'output']
@@ -108,54 +108,110 @@ def get_non_matchable_modules() -> Set[Callable]:
torch.quantization.FakeQuantizeBase,
])
def get_reversed_fusions() -> Set[Tuple[Callable, Callable]]:
NSFusionElType = Union[
Callable, # call_function or call_module type, example: F.linear or nn.Conv2d
str, # call_method name, example: "dequantize"
Tuple[str, Any], # call_method name and first argument, example: ("to", torch.float16)
]
NSFusionType = Union[
Tuple[NSFusionElType, NSFusionElType],
Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]
def get_reversed_fusions() -> Set[Tuple[NSFusionType, int]]:
"""
Set of potential fusions, in reverse order. The order is reversed
to match how fusion patterns are defined in quantization code.
Fusion format:
((fusion_op_0, fusion_op_1), base_op_idx)
Where base_op_idx is the idx of the op we should use to match other related
ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
of 0 represents the first op in regular (non-reverse) order, 1 represents the
second op, etc.
"""
# TODO(future PR): remove the custom syntax for defining fusion patterns
# and reuse either quantization's syntax or something else.
return set([
(F.relu, F.linear),
(nn.ReLU, nn.Conv2d),
((F.relu, F.linear), 0),
((nn.ReLU, nn.Conv2d), 0),
# linear-relu fp16 emulation:
# fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
((("to", torch.float16), F.relu, F.linear, "dequantize"), 1),
])
# TODO(future PR): we should see if we can reuse quantization's fusion
# patterns here.
def end_node_matches_reversed_fusion(
end_node: Node,
reversed_fusion: Tuple[Callable, Callable],
reversed_fusion: NSFusionType,
gm: GraphModule,
) -> bool:
"""
Returns true if a pattern ending with `end_node` matches
the fusion pattern.
"""
if end_node.op == 'call_function':
cur_node = end_node
for fusion_idx in range(len(reversed_fusion)):
cur_fusion_op = reversed_fusion[fusion_idx]
if cur_node.target != cur_fusion_op:
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
cur_node = end_node
for fusion_idx in range(len(reversed_fusion)):
cur_fusion_el = reversed_fusion[fusion_idx]
if cur_node.op == 'call_function':
fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and \
(not isinstance(cur_fusion_el, type))
if fusion_el_is_fun:
if cur_node.target != cur_fusion_el:
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
else:
return False
return True
elif end_node.op == 'call_module':
cur_node = end_node
for fusion_idx in range(len(reversed_fusion)):
cur_fusion_op = reversed_fusion[fusion_idx]
assert isinstance(cur_node.target, str)
target_mod = getattr_from_fqn(gm, cur_node.target)
if not isinstance(cur_fusion_op, type):
return False
if not isinstance(target_mod, cur_fusion_op):
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
elif cur_node.op == 'call_module':
fusion_el_is_mod = isinstance(cur_fusion_el, type)
if fusion_el_is_mod:
assert isinstance(cur_node.target, str)
target_mod = getattr_from_fqn(gm, cur_node.target)
if not isinstance(cur_fusion_el, type):
return False
if not isinstance(target_mod, cur_fusion_el):
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
else:
return False
return True
return False
elif cur_node.op == 'call_method':
fusion_el_is_meth_with_second_arg = \
isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
if fusion_el_is_meth_without_args:
if cur_node.target != cur_fusion_el:
return False
else:
assert isinstance(cur_fusion_el, tuple)
if cur_node.target != cur_fusion_el[0]:
return False
elif len(cur_node.args) < 2:
return False
elif cur_node.args[1] != cur_fusion_el[1]:
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
else:
return False
else:
return False
return True
class _NSGraphMatchableSubgraphsIterator:
@@ -195,22 +251,33 @@ class _NSGraphMatchableSubgraphsIterator:
# for subgraphs which are single nodes, start_node == end_node
# for subgraphs with more than one node, start node != end_node
cur_start_node = cur_end_node
# Subgraphs like linear-relu have the base node as the start node.
# Subgraphs like dequantize-linear-relu-to(torch.float16) have the
# base node as the second node.
# The cur_base_op_node var will move to the actual node during
# the fusion matching later in this code block.
cur_base_op_node = cur_end_node
# Check for potential fusions. For now, we are greedy
# and always skip all non-base nodes of a fusion. For example,
# if we match linear-relu backwards, we will always skip the
# relu node and attempt to match the linear node. This can
# be made configurable later if needed.
for _reverse_fusion_ops in get_reversed_fusions():
for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
is_match = end_node_matches_reversed_fusion(
cur_end_node, _reverse_fusion_ops, self.gm)
if is_match:
# navigate to the base node
for fusion_idx in range(len(_reverse_fusion_ops) - 1):
for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
self.seen_nodes.add(cur_start_node)
# for now, assume that there are no other nodes
# which need to be added to the stack
cur_start_node = cur_start_node.args[0] # type: ignore
# if the base op index matches the current node, set it
rev_base_op_idx = \
len(_reverse_fusion_ops) - 2 - base_op_idx
if rev_fusion_idx == rev_base_op_idx:
cur_base_op_node = cur_start_node
break
self.seen_nodes.add(cur_start_node)
@@ -223,10 +290,12 @@
# note: this check is done on the start_node, i.e.
# if we are matching linear-relu in reverse, this would do the matchable
# check on the linear
if not self._is_matchable(cur_start_node):
if not self._is_matchable(cur_base_op_node):
continue
return NSSubgraph(start_node=cur_start_node, end_node=cur_end_node)
return NSSubgraph(
start_node=cur_start_node, end_node=cur_end_node,
base_op_node=cur_base_op_node)
raise StopIteration
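
The index arithmetic in the fusion navigation above can be sanity-checked in
isolation: starting at the end node, each iteration hops one step backwards
through `args[0]`, so after `rev_fusion_idx + 1` hops the current node's
regular-order index is `len(fusion) - 2 - rev_fusion_idx`, which equals
`base_op_idx` exactly when `rev_fusion_idx == len(fusion) - 2 - base_op_idx`.
A small standalone sketch (illustrative only, not code from this PR):

```
# Illustrative only: the fp16 emulation pattern in regular (non-reversed) order.
ops = ["dequantize", "linear", "relu", "to(torch.float16)"]
base_op_idx = 1  # "linear", specified in regular order as in get_reversed_fusions

cur_idx = len(ops) - 1  # start at the end node, "to(torch.float16)"
for rev_fusion_idx in range(len(ops) - 1):
    cur_idx -= 1  # one backward hop through node.args[0]
    rev_base_op_idx = len(ops) - 2 - base_op_idx
    if rev_fusion_idx == rev_base_op_idx:
        print("base_op_node:", ops[cur_idx])  # base_op_node: linear
print("start_node:", ops[cur_idx])            # start_node: dequantize
```
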
@@ -263,38 +332,55 @@ class GraphMatchingException(Exception):
"""
pass
class NodeTypeRelationship(enum.Enum):
class SugraphTypeRelationship(enum.Enum):
# same type
# example: F.linear and toq.linear, or nn.Conv2d and nn.Conv2d
EQUAL = enum.auto()
# same node_relationship set, but not the same type
# same subgraph_relationship set, but not the same type
# example: F.linear and toq.linear
RELATED_BUT_NOT_EQUAL = enum.auto()
# not related
NOT_RELATED = enum.auto()
def _get_node_relationship_type(
node_a: Node,
node_b: Node,
def _get_subgraph_relationship_type(
subgraph_a: NSSubgraph,
subgraph_b: NSSubgraph,
gm_a: GraphModule,
gm_b: GraphModule,
type_a_related_to_b: Set[Tuple[Callable, Callable]],
) -> NodeTypeRelationship:
) -> SugraphTypeRelationship:
node_a = subgraph_a.base_op_node
node_b = subgraph_b.base_op_node
# TODO(next): make this code handle matching by what is before the base op
if node_a.op != node_b.op:
# for now, comparing call_module to call_function is not supported
# this can be added later if needed
return NodeTypeRelationship.NOT_RELATED
return SugraphTypeRelationship.NOT_RELATED
if node_a.op == 'call_function':
if node_a.target == node_b.target:
# nodes with equivalent targets always match (i.e. F.linear and F.linear)
return NodeTypeRelationship.EQUAL
node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
if node_a_has_prev and (not node_b_has_prev):
return SugraphTypeRelationship.RELATED_BUT_NOT_EQUAL
elif (not node_a_has_prev) and node_b_has_prev:
return SugraphTypeRelationship.RELATED_BUT_NOT_EQUAL
elif (not node_a_has_prev) and (not node_b_has_prev):
return SugraphTypeRelationship.EQUAL
else:
# TODO(future PR): check for matches start_op_node and base_op_node
return SugraphTypeRelationship.EQUAL
key = (node_a.target, node_b.target)
if key in type_a_related_to_b:
return NodeTypeRelationship.RELATED_BUT_NOT_EQUAL
return SugraphTypeRelationship.RELATED_BUT_NOT_EQUAL
else:
return NodeTypeRelationship.NOT_RELATED
return SugraphTypeRelationship.NOT_RELATED
elif node_a.op == 'call_module':
assert (subgraph_a.base_op_node == subgraph_a.start_node and
subgraph_b.base_op_node == subgraph_b.start_node), \
"Matching call_module patterns where base_op_node != start_node is not supported yet"
# for call_module, we need to look up the modules to do the type check
assert isinstance(node_a.target, str)
mod_a = getattr_from_fqn(gm_a, node_a.target)
@@ -302,13 +388,13 @@ def _get_node_relationship_type(
mod_b = getattr_from_fqn(gm_b, node_b.target)
# modules with equivalent types always match (i.e. nn.Conv2d and nn.Conv2d)
if type(mod_a) == type(mod_b):
return NodeTypeRelationship.EQUAL
return SugraphTypeRelationship.EQUAL
key = (type(mod_a), type(mod_b))
if key in type_a_related_to_b:
return NodeTypeRelationship.RELATED_BUT_NOT_EQUAL
return SugraphTypeRelationship.RELATED_BUT_NOT_EQUAL
else:
return NodeTypeRelationship.NOT_RELATED
return NodeTypeRelationship.NOT_RELATED
return SugraphTypeRelationship.NOT_RELATED
return SugraphTypeRelationship.NOT_RELATED
def _get_name_for_subgraph(
subgraph_a: NSSubgraph,
@@ -348,7 +434,7 @@ def _get_name_for_subgraph(
of the graphs match, both of these subgraphs will get the same name without
(1) and (2) knowing anything about each other.
"""
target_type = _get_node_target_type(subgraph_a.start_node, gm_a)
target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
target_base_type = None
for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
if target_type in sets_of_related_ops:
@@ -473,19 +559,19 @@ def get_matching_subgraph_pairs(
# check for results and determine what to do next
if cur_subgraph_a is not None and cur_subgraph_b is not None:
# both nodes were fetched, check for node_relationship
# note: node_relationship is checked on the start node, i.e.
# if a linear-relu pattern is checked, we would check for node_relationship
# both nodes were fetched, check for subgraph_relationship
# note: subgraph_relationship is checked on the start node, i.e.
# if a linear-relu pattern is checked, we would check for subgraph_relationship
# of the linear
node_relationship = _get_node_relationship_type(
cur_subgraph_a.start_node, cur_subgraph_b.start_node,
subgraph_relationship = _get_subgraph_relationship_type(
cur_subgraph_a, cur_subgraph_b,
gm_a, gm_b, type_a_related_to_b)
if node_relationship == NodeTypeRelationship.NOT_RELATED:
if subgraph_relationship == SugraphTypeRelationship.NOT_RELATED:
msg = f"""
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b}) are not related"""
raise GraphMatchingException(msg)
elif node_relationship == NodeTypeRelationship.EQUAL:
elif subgraph_relationship == SugraphTypeRelationship.EQUAL:
# For now, skip nodes with equal types. In the future, this can
# be made configurable.
continue


@@ -10,5 +10,5 @@ class NSSingleResultValuesType(str, enum.Enum):
NSSubgraph = NamedTuple(
'NSSubgraph',
[('start_node', Node), ('end_node', Node)]
[('start_node', Node), ('end_node', Node), ('base_op_node', Node)]
)
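
With the extra field, every `NSSubgraph` construction now has to supply all
three nodes. A minimal construction sketch for a single-node subgraph
(illustrative only; the import path is assumed from this PR's source tree and
may differ):

```
import torch
import torch.fx
from torch.quantization.ns.ns_types import NSSubgraph  # path assumed, may differ

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)

gm = torch.fx.symbolic_trace(TinyLinear())
linear_node = next(n for n in gm.graph.nodes if n.op == 'call_module')
# For a single-node subgraph, start_node, end_node and base_op_node all
# point at the same node.
subgraph = NSSubgraph(
    start_node=linear_node, end_node=linear_node, base_op_node=linear_node)
```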


@@ -233,8 +233,8 @@ def _extract_weights_impl(
nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
for match_name, match in matched_subgraph_pairs.items():
subgraph_a, subgraph_b = match
nodes_and_names_to_instrument_a.append((subgraph_a.start_node, match_name))
nodes_and_names_to_instrument_b.append((subgraph_b.start_node, match_name))
nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))
# populate the results, one model at a time
results: NSResultsType = {}