mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
ns for fx: add partial support for subgraphs with base_op_node (#54254)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54254 In fp16 emulation, we now have patterns such as ``` ... -> dequantize -> linear -> relu -> to(torch.float16) -> ... ``` This PR adds support for * specifying a subgraph's "base_op_node", which is the node with the op which should be matched to related nodes. In the example above, "base_op_node" would be the linear node, and it would be the second node in the matched pattern. * matching these fusion patterns and properly setting "base_op_node" based on pattern and index * using "base_op_node" instead of "start_node" throughout the NS codebase wherever the intent is to match subgraphs or create names for subgraphs. At the end of this PR, matching unshadowed activations with an example fp16 emulation pattern works e2e. I'm saving the following work for future PRs (soon), mostly to keep PR size manageable: * adding weight matching (will require some changes to function which extracts weights) * adding shadowed activation matching (will require some changes to shadow copying) * adding input logging for these patterns (will likely require some changes as well) Test Plan: ``` python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_linear_fp16 ``` Imported from OSS Reviewed By: jerryzh168 Differential Revision: D27158199 fbshipit-source-id: 49fc445395452fda62e3c7a243544190f9af691c
This commit is contained in:
parent
454832e5fa
commit
182d8c375c
|
|
@ -539,6 +539,31 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
|
|||
results_len=2,
|
||||
should_log_inputs=True)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_linear_fp16(self):
|
||||
class M(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.w1 = nn.Parameter(torch.Tensor(4, 4))
|
||||
self.b1 = nn.Parameter(torch.zeros(4))
|
||||
torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
|
||||
|
||||
def forward(self, x):
|
||||
x = F.linear(x, self.w1, self.b1)
|
||||
x = F.relu(x)
|
||||
return x
|
||||
|
||||
qconfig_dict = {'': torch.quantization.float16_static_qconfig}
|
||||
m = M().eval()
|
||||
expected_occurrence = {
|
||||
ns.call_module(OutputLogger): 1,
|
||||
}
|
||||
self._test_match_activations(
|
||||
m, (torch.randn(4, 4),),
|
||||
prepared_expected_node_occurrence=expected_occurrence,
|
||||
results_len=1,
|
||||
qconfig_dict=qconfig_dict)
|
||||
|
||||
|
||||
class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from torch.fx.graph import Graph, Node
|
|||
from .utils import getattr_from_fqn
|
||||
from .ns_types import NSSubgraph
|
||||
|
||||
from typing import Dict, Tuple, List, Optional, Set, Callable, Any
|
||||
from typing import Dict, Tuple, List, Optional, Set, Callable, Any, Union
|
||||
|
||||
def _get_output_nodes(g: Graph) -> List[Node]:
|
||||
return [n for n in g.nodes if n.op == 'output']
|
||||
|
|
@ -108,54 +108,110 @@ def get_non_matchable_modules() -> Set[Callable]:
|
|||
torch.quantization.FakeQuantizeBase,
|
||||
])
|
||||
|
||||
def get_reversed_fusions() -> Set[Tuple[Callable, Callable]]:
|
||||
NSFusionElType = Union[
|
||||
Callable, # call_function or call_module type, example: F.linear or nn.Conv2d
|
||||
str, # call_method name, example: "dequantize"
|
||||
Tuple[str, Any], # call_method name and first argument, example: ("to", torch.float16)
|
||||
]
|
||||
NSFusionType = Union[
|
||||
Tuple[NSFusionElType, NSFusionElType],
|
||||
Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
|
||||
]
|
||||
|
||||
def get_reversed_fusions() -> Set[Tuple[NSFusionType, int]]:
|
||||
"""
|
||||
Set of potential fusions, in reverse order. The order is reversed
|
||||
to match how fusion patterns are defined in quantization code.
|
||||
|
||||
Fusion format:
|
||||
((fusion_op_0, fusion_op_1), base_op_idx)
|
||||
|
||||
Where base_op_idx is the idx of the op we should use to match other related
|
||||
ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
|
||||
of 0 represents the first op in regular (non-reverse) order, 1 represents the
|
||||
second op, etc.
|
||||
"""
|
||||
# TODO(future PR): remove the custom syntax for defining fusion patterns
|
||||
# and reuse either quantization's syntax or something else.
|
||||
return set([
|
||||
(F.relu, F.linear),
|
||||
(nn.ReLU, nn.Conv2d),
|
||||
((F.relu, F.linear), 0),
|
||||
((nn.ReLU, nn.Conv2d), 0),
|
||||
# linear-relu fp16 emulation:
|
||||
# fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
|
||||
((("to", torch.float16), F.relu, F.linear, "dequantize"), 1),
|
||||
])
|
||||
|
||||
# TODO(future PR): we should see if we can reuse quantization's fusion
|
||||
# patterns here.
|
||||
def end_node_matches_reversed_fusion(
|
||||
end_node: Node,
|
||||
reversed_fusion: Tuple[Callable, Callable],
|
||||
reversed_fusion: NSFusionType,
|
||||
gm: GraphModule,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns true if a pattern ending with `end_node` matches
|
||||
the fusion pattern.
|
||||
"""
|
||||
if end_node.op == 'call_function':
|
||||
cur_node = end_node
|
||||
for fusion_idx in range(len(reversed_fusion)):
|
||||
cur_fusion_op = reversed_fusion[fusion_idx]
|
||||
if cur_node.target != cur_fusion_op:
|
||||
return False
|
||||
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
|
||||
cur_node = cur_node.args[0]
|
||||
cur_node = end_node
|
||||
for fusion_idx in range(len(reversed_fusion)):
|
||||
cur_fusion_el = reversed_fusion[fusion_idx]
|
||||
|
||||
if cur_node.op == 'call_function':
|
||||
fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and \
|
||||
(not isinstance(cur_fusion_el, type))
|
||||
if fusion_el_is_fun:
|
||||
if cur_node.target != cur_fusion_el:
|
||||
return False
|
||||
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
|
||||
cur_node = cur_node.args[0]
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
elif end_node.op == 'call_module':
|
||||
cur_node = end_node
|
||||
for fusion_idx in range(len(reversed_fusion)):
|
||||
cur_fusion_op = reversed_fusion[fusion_idx]
|
||||
assert isinstance(cur_node.target, str)
|
||||
target_mod = getattr_from_fqn(gm, cur_node.target)
|
||||
if not isinstance(cur_fusion_op, type):
|
||||
return False
|
||||
if not isinstance(target_mod, cur_fusion_op):
|
||||
return False
|
||||
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
|
||||
cur_node = cur_node.args[0]
|
||||
|
||||
elif cur_node.op == 'call_module':
|
||||
fusion_el_is_mod = isinstance(cur_fusion_el, type)
|
||||
if fusion_el_is_mod:
|
||||
assert isinstance(cur_node.target, str)
|
||||
target_mod = getattr_from_fqn(gm, cur_node.target)
|
||||
if not isinstance(cur_fusion_el, type):
|
||||
return False
|
||||
if not isinstance(target_mod, cur_fusion_el):
|
||||
return False
|
||||
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
|
||||
cur_node = cur_node.args[0]
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
elif cur_node.op == 'call_method':
|
||||
fusion_el_is_meth_with_second_arg = \
|
||||
isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
|
||||
fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
|
||||
if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
|
||||
if fusion_el_is_meth_without_args:
|
||||
if cur_node.target != cur_fusion_el:
|
||||
return False
|
||||
else:
|
||||
assert isinstance(cur_fusion_el, tuple)
|
||||
if cur_node.target != cur_fusion_el[0]:
|
||||
return False
|
||||
elif len(cur_node.args) < 2:
|
||||
return False
|
||||
elif cur_node.args[1] != cur_fusion_el[1]:
|
||||
return False
|
||||
|
||||
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
|
||||
cur_node = cur_node.args[0]
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class _NSGraphMatchableSubgraphsIterator:
|
||||
|
|
@ -195,22 +251,33 @@ class _NSGraphMatchableSubgraphsIterator:
|
|||
# for subgraphs which are single nodes, start_node == end_node
|
||||
# for subgraphs with more than one node, start node != end_node
|
||||
cur_start_node = cur_end_node
|
||||
# Subgraphs like linear-relu have the base node as the start node.
|
||||
# Subgraphs like dequantize-linear-relu-to(torch.float16) have the
|
||||
# base node as the second node.
|
||||
# The cur_base_op_node var will move to the actual node during
|
||||
# the fusion matching later in this code block.
|
||||
cur_base_op_node = cur_end_node
|
||||
|
||||
# Check for potential fusions. For now, we are greedy
|
||||
# and always skip all non-base nodes of a fusion. For example,
|
||||
# if we match linear-relu backwards, we will always skip the
|
||||
# relu node and attempt to match the linear node. This can
|
||||
# be made configurable later if needed.
|
||||
for _reverse_fusion_ops in get_reversed_fusions():
|
||||
for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
|
||||
is_match = end_node_matches_reversed_fusion(
|
||||
cur_end_node, _reverse_fusion_ops, self.gm)
|
||||
if is_match:
|
||||
# navigate to the base node
|
||||
for fusion_idx in range(len(_reverse_fusion_ops) - 1):
|
||||
for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
|
||||
self.seen_nodes.add(cur_start_node)
|
||||
# for now, assume that there are no other nodes
|
||||
# which need to be added to the stack
|
||||
cur_start_node = cur_start_node.args[0] # type: ignore
|
||||
# if the base op index matches the current node, set it
|
||||
rev_base_op_idx = \
|
||||
len(_reverse_fusion_ops) - 2 - base_op_idx
|
||||
if rev_fusion_idx == rev_base_op_idx:
|
||||
cur_base_op_node = cur_start_node
|
||||
break
|
||||
|
||||
self.seen_nodes.add(cur_start_node)
|
||||
|
|
@ -223,10 +290,12 @@ class _NSGraphMatchableSubgraphsIterator:
|
|||
# note: this check is done on the start_node, i.e.
|
||||
# if we are matching linear-relu in reverse, this would do the matchable
|
||||
# check on the linear
|
||||
if not self._is_matchable(cur_start_node):
|
||||
if not self._is_matchable(cur_base_op_node):
|
||||
continue
|
||||
|
||||
return NSSubgraph(start_node=cur_start_node, end_node=cur_end_node)
|
||||
return NSSubgraph(
|
||||
start_node=cur_start_node, end_node=cur_end_node,
|
||||
base_op_node=cur_base_op_node)
|
||||
|
||||
raise StopIteration
|
||||
|
||||
|
|
@ -263,38 +332,55 @@ class GraphMatchingException(Exception):
|
|||
"""
|
||||
pass
|
||||
|
||||
class NodeTypeRelationship(enum.Enum):
|
||||
class SugraphTypeRelationship(enum.Enum):
|
||||
# same type
|
||||
# example: F.linear and toq.linear, or nn.Conv2d and nn.Conv2d
|
||||
EQUAL = enum.auto()
|
||||
# same node_relationship set, but not the same type
|
||||
# same subgraph_relationship set, but not the same type
|
||||
# example: F.linear and toq.linear
|
||||
RELATED_BUT_NOT_EQUAL = enum.auto()
|
||||
# not related
|
||||
NOT_RELATED = enum.auto()
|
||||
|
||||
def _get_node_relationship_type(
|
||||
node_a: Node,
|
||||
node_b: Node,
|
||||
def _get_subgraph_relationship_type(
|
||||
subgraph_a: NSSubgraph,
|
||||
subgraph_b: NSSubgraph,
|
||||
gm_a: GraphModule,
|
||||
gm_b: GraphModule,
|
||||
type_a_related_to_b: Set[Tuple[Callable, Callable]],
|
||||
) -> NodeTypeRelationship:
|
||||
) -> SugraphTypeRelationship:
|
||||
node_a = subgraph_a.base_op_node
|
||||
node_b = subgraph_b.base_op_node
|
||||
|
||||
# TODO(next): make this code handle matching by what is before the base op
|
||||
if node_a.op != node_b.op:
|
||||
# for now, comparing call_module to call_function is not supported
|
||||
# this can be added later if needed
|
||||
return NodeTypeRelationship.NOT_RELATED
|
||||
return SugraphTypeRelationship.NOT_RELATED
|
||||
|
||||
if node_a.op == 'call_function':
|
||||
if node_a.target == node_b.target:
|
||||
# nodes with equivalent targets always match (i.e. F.linear and F.linear)
|
||||
return NodeTypeRelationship.EQUAL
|
||||
node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
|
||||
node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
|
||||
if node_a_has_prev and (not node_b_has_prev):
|
||||
return SugraphTypeRelationship.RELATED_BUT_NOT_EQUAL
|
||||
elif (not node_a_has_prev) and node_b_has_prev:
|
||||
return SugraphTypeRelationship.RELATED_BUT_NOT_EQUAL
|
||||
elif (not node_a_has_prev) and (not node_b_has_prev):
|
||||
return SugraphTypeRelationship.EQUAL
|
||||
else:
|
||||
# TODO(future PR): check for matches start_op_node and base_op_node
|
||||
return SugraphTypeRelationship.EQUAL
|
||||
|
||||
key = (node_a.target, node_b.target)
|
||||
if key in type_a_related_to_b:
|
||||
return NodeTypeRelationship.RELATED_BUT_NOT_EQUAL
|
||||
return SugraphTypeRelationship.RELATED_BUT_NOT_EQUAL
|
||||
else:
|
||||
return NodeTypeRelationship.NOT_RELATED
|
||||
return SugraphTypeRelationship.NOT_RELATED
|
||||
elif node_a.op == 'call_module':
|
||||
assert (subgraph_a.base_op_node == subgraph_a.start_node and
|
||||
subgraph_b.base_op_node == subgraph_b.start_node), \
|
||||
"Matching call_module patterns where base_op_node != start_node is not supported yet"
|
||||
# for call_module, we need to look up the modules to do the type check
|
||||
assert isinstance(node_a.target, str)
|
||||
mod_a = getattr_from_fqn(gm_a, node_a.target)
|
||||
|
|
@ -302,13 +388,13 @@ def _get_node_relationship_type(
|
|||
mod_b = getattr_from_fqn(gm_b, node_b.target)
|
||||
# modules with equivalent types always match (i.e. nn.Conv2d and nn.Conv2d)
|
||||
if type(mod_a) == type(mod_b):
|
||||
return NodeTypeRelationship.EQUAL
|
||||
return SugraphTypeRelationship.EQUAL
|
||||
key = (type(mod_a), type(mod_b))
|
||||
if key in type_a_related_to_b:
|
||||
return NodeTypeRelationship.RELATED_BUT_NOT_EQUAL
|
||||
return SugraphTypeRelationship.RELATED_BUT_NOT_EQUAL
|
||||
else:
|
||||
return NodeTypeRelationship.NOT_RELATED
|
||||
return NodeTypeRelationship.NOT_RELATED
|
||||
return SugraphTypeRelationship.NOT_RELATED
|
||||
return SugraphTypeRelationship.NOT_RELATED
|
||||
|
||||
def _get_name_for_subgraph(
|
||||
subgraph_a: NSSubgraph,
|
||||
|
|
@ -348,7 +434,7 @@ def _get_name_for_subgraph(
|
|||
of the graphs match, both of these subgraphs will get the same name without
|
||||
(1) and (2) knowing anything about each other.
|
||||
"""
|
||||
target_type = _get_node_target_type(subgraph_a.start_node, gm_a)
|
||||
target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
|
||||
target_base_type = None
|
||||
for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
|
||||
if target_type in sets_of_related_ops:
|
||||
|
|
@ -473,19 +559,19 @@ def get_matching_subgraph_pairs(
|
|||
|
||||
# check for results and determine what to do next
|
||||
if cur_subgraph_a is not None and cur_subgraph_b is not None:
|
||||
# both nodes were fetched, check for node_relationship
|
||||
# note: node_relationship is checked on the start node, i.e.
|
||||
# if a linear-relu pattern is checked, we would check for node_relationship
|
||||
# both nodes were fetched, check for subgraph_relationship
|
||||
# note: subgraph_relationship is checked on the start node, i.e.
|
||||
# if a linear-relu pattern is checked, we would check for subgraph_relationship
|
||||
# of the linear
|
||||
node_relationship = _get_node_relationship_type(
|
||||
cur_subgraph_a.start_node, cur_subgraph_b.start_node,
|
||||
subgraph_relationship = _get_subgraph_relationship_type(
|
||||
cur_subgraph_a, cur_subgraph_b,
|
||||
gm_a, gm_b, type_a_related_to_b)
|
||||
if node_relationship == NodeTypeRelationship.NOT_RELATED:
|
||||
if subgraph_relationship == SugraphTypeRelationship.NOT_RELATED:
|
||||
msg = f"""
|
||||
({cur_subgraph_a}, {type_start_a}) and
|
||||
({cur_subgraph_b}, {type_start_b}) are not related"""
|
||||
raise GraphMatchingException(msg)
|
||||
elif node_relationship == NodeTypeRelationship.EQUAL:
|
||||
elif subgraph_relationship == SugraphTypeRelationship.EQUAL:
|
||||
# For now, skip nodes with equal types. In the future, this can
|
||||
# be made configurable.
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -10,5 +10,5 @@ class NSSingleResultValuesType(str, enum.Enum):
|
|||
|
||||
NSSubgraph = NamedTuple(
|
||||
'NSSubgraph',
|
||||
[('start_node', Node), ('end_node', Node)]
|
||||
[('start_node', Node), ('end_node', Node), ('base_op_node', Node)]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -233,8 +233,8 @@ def _extract_weights_impl(
|
|||
nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
|
||||
for match_name, match in matched_subgraph_pairs.items():
|
||||
subgraph_a, subgraph_b = match
|
||||
nodes_and_names_to_instrument_a.append((subgraph_a.start_node, match_name))
|
||||
nodes_and_names_to_instrument_b.append((subgraph_b.start_node, match_name))
|
||||
nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
|
||||
nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))
|
||||
|
||||
# populate the results, one model at a time
|
||||
results: NSResultsType = {}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user