ns for fx: move pattern utils to separate file (#55805)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55805

No logic change, just moving util functions to a separate file.

Test Plan:
```
python test/test_quantization.py TestFXGraphMatcher
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D27719982

fbshipit-source-id: c80d5397c1efeb9fc83eacaa532ecbde557cca3f
Vasiliy Kuznetsov 2021-04-15 16:01:22 -07:00 committed by Facebook GitHub Bot
parent b461104554
commit c8209a7336
2 changed files with 271 additions and 248 deletions


@@ -4,12 +4,6 @@ import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.intrinsic as nni
toq = torch.ops.quantized
from torch.fx import GraphModule
@@ -17,130 +11,18 @@ from torch.fx.graph import Graph, Node
from .utils import getattr_from_fqn
from .ns_types import NSSubgraph
from torch.quantization.fx.pattern_utils import get_default_quant_patterns
from .pattern_utils import (
get_base_name_to_sets_of_related_ops,
get_type_a_related_to_b,
get_reversed_fusions,
end_node_matches_reversed_fusion,
)
from typing import Dict, Tuple, List, Optional, Set, Callable, Any, Union
from typing import Dict, Tuple, List, Optional, Set, Callable, Any
def _get_output_nodes(g: Graph) -> List[Node]:
return [n for n in g.nodes if n.op == 'output']
def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[Callable]]:
base_name_to_sets_of_related_ops: Dict[str, Set[Callable]] = {
# conv modules
'torch.nn.Conv1d': set([
nn.Conv1d,
nnq.Conv1d,
nniqat.ConvBn1d,
nniqat.ConvBnReLU1d,
nniq.ConvReLU1d,
nni.ConvReLU1d,
]),
'torch.nn.Conv2d': set([
nn.Conv2d,
nnq.Conv2d,
nnqat.Conv2d,
nniqat.ConvBn2d,
nniqat.ConvBnReLU2d,
nniqat.ConvReLU2d,
nniq.ConvReLU2d,
nni.ConvReLU2d,
]),
'torch.nn.Conv3d': set([
nn.Conv3d,
nnq.Conv3d,
nnqat.Conv3d,
nniqat.ConvBn3d,
nniqat.ConvBnReLU3d,
nniqat.ConvReLU3d,
nniq.ConvReLU3d,
nni.ConvReLU3d,
]),
# conv functionals
'torch.nn.functional.conv1d': set([
F.conv1d,
toq.conv1d,
toq.conv1d_relu,
]),
'torch.nn.functional.conv2d': set([
F.conv2d,
toq.conv2d,
toq.conv2d_relu,
]),
'torch.nn.functional.conv3d': set([
F.conv3d,
toq.conv3d,
toq.conv3d_relu,
]),
# linear modules
'torch.nn.Linear': set([
nn.Linear,
nnq.Linear,
nni.LinearReLU,
nniq.LinearReLU,
nnqat.Linear,
nnqd.Linear,
nniqat.LinearReLU,
]),
# linear functionals
'torch.nn.functional.linear': set([
F.linear,
toq.linear,
toq.linear_relu,
]),
# LSTM
'torch.nn.LSTM': set([
nn.LSTM,
nnqd.LSTM,
]),
# add
'torch.add': set([
torch.add,
toq.add,
operator.add, # x + y
]),
# cat
'torch.cat': set([
torch.cat,
toq.cat,
]),
# mul
'torch.mul': set([
torch.mul,
toq.mul,
]),
# relu
'torch.relu': set([
F.relu,
]),
# maxpool2d
'torch.nn.MaxPool2d': set([
nn.MaxPool2d,
]),
# sigmoid
'torch.sigmoid': set([
torch.sigmoid,
]),
}
return base_name_to_sets_of_related_ops
def get_type_a_related_to_b(
base_name_to_sets_of_related_ops: Dict[str, Set[Callable]],
) -> Set[Tuple[Callable, Callable]]:
# TODO(future PR): allow customizations
# TODO(future PR): reuse existing quantization mappings
# TODO(future PR): add the rest of modules and ops here
type_a_related_to_b: Set[Tuple[Callable, Callable]] = set()
for base_name, s in base_name_to_sets_of_related_ops.items():
s_list = list(s)
# add every bidirectional pair
for idx_0 in range(0, len(s_list) - 1):
for idx_1 in range(idx_0 + 1, len(s_list)):
type_a_related_to_b.add((s_list[idx_0], s_list[idx_1]))
type_a_related_to_b.add((s_list[idx_1], s_list[idx_0]))
return type_a_related_to_b
def get_non_matchable_functions() -> Set[Callable]:
"""
`call_function` nodes pointing to these functions are non-matchable.
@@ -161,129 +43,6 @@ def get_non_matchable_modules() -> Set[Callable]:
torch.quantization.FakeQuantizeBase,
])
NSFusionElType = Union[
Callable, # call_function or call_module type, example: F.linear or nn.Conv2d
str, # call_method name, example: "dequantize"
Tuple[str, Any], # call_method name and first argument, example: ("to", torch.float16)
]
NSFusionType = Union[
Tuple[NSFusionElType, NSFusionElType],
Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]
def get_reversed_fusions() -> Set[Tuple[NSFusionType, int]]:
"""
Set of potential fusions, in reverse order. The order is reversed
to match how fusion patterns are defined in quantization code.
Fusion format:
((fusion_op_0, fusion_op_1), base_op_idx)
Where base_op_idx is the idx of the op we should use to match other related
ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
of 0 represents the first op in regular (non-reverse) order, 1 represents the
second op, etc.
"""
results: Set[Tuple[NSFusionType, int]] = set([])
# Possible syntaxes:
# * single op: torch.nn.Conv2d
# * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
# For fusions, we only care about patterns composed of multiple ops.
# TODO(future PR): allow customizations from default patterns.
all_quant_patterns = get_default_quant_patterns()
default_base_op_idx = 0
for quant_pattern, _quant_handler in all_quant_patterns.items():
# this only takes patterns of multiple ops
if isinstance(quant_pattern, tuple):
results.add((quant_pattern, default_base_op_idx)) # type: ignore
# After this point, results contains values such as
# [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]
# Patterns for matching fp16 emulation are not specified in the quantization
# fusion mappings. For now, define them here.
fp16_em_base_op_idx = 1
patterns_to_add = [
# linear-relu fp16 emulation:
# fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
((("to", torch.float16), F.relu, F.linear, "dequantize"), fp16_em_base_op_idx,),
]
for p in patterns_to_add:
results.add(p)
return results
def end_node_matches_reversed_fusion(
end_node: Node,
reversed_fusion: NSFusionType,
gm: GraphModule,
) -> bool:
"""
Returns true if a pattern ending with `end_node` matches
the fusion pattern.
"""
cur_node = end_node
for fusion_idx in range(len(reversed_fusion)):
cur_fusion_el = reversed_fusion[fusion_idx]
if cur_node.op == 'call_function':
fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and \
(not isinstance(cur_fusion_el, type))
if fusion_el_is_fun:
if cur_node.target != cur_fusion_el:
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
else:
return False
elif cur_node.op == 'call_module':
fusion_el_is_mod = isinstance(cur_fusion_el, type)
if fusion_el_is_mod:
assert isinstance(cur_node.target, str)
target_mod = getattr_from_fqn(gm, cur_node.target)
if not isinstance(cur_fusion_el, type):
return False
if not isinstance(target_mod, cur_fusion_el):
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
else:
return False
elif cur_node.op == 'call_method':
fusion_el_is_meth_with_second_arg = \
isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
if fusion_el_is_meth_without_args:
if cur_node.target != cur_fusion_el:
return False
else:
assert isinstance(cur_fusion_el, tuple)
if cur_node.target != cur_fusion_el[0]:
return False
elif len(cur_node.args) < 2:
return False
elif cur_node.args[1] != cur_fusion_el[1]:
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
else:
return False
else:
return False
return True
class _NSGraphMatchableSubgraphsIterator:
"""


@@ -0,0 +1,264 @@
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
toq = torch.ops.quantized
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.qat as nniqat
import torch.nn.intrinsic as nni
from torch.fx import GraphModule
from torch.fx.graph import Node
from .utils import getattr_from_fqn
from torch.quantization.fx.pattern_utils import get_default_quant_patterns
from typing import Dict, Tuple, Set, Callable, Any, Union
def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[Callable]]:
base_name_to_sets_of_related_ops: Dict[str, Set[Callable]] = {
# conv modules
'torch.nn.Conv1d': set([
nn.Conv1d,
nnq.Conv1d,
nniqat.ConvBn1d,
nniqat.ConvBnReLU1d,
nniq.ConvReLU1d,
nni.ConvReLU1d,
]),
'torch.nn.Conv2d': set([
nn.Conv2d,
nnq.Conv2d,
nnqat.Conv2d,
nniqat.ConvBn2d,
nniqat.ConvBnReLU2d,
nniqat.ConvReLU2d,
nniq.ConvReLU2d,
nni.ConvReLU2d,
]),
'torch.nn.Conv3d': set([
nn.Conv3d,
nnq.Conv3d,
nnqat.Conv3d,
nniqat.ConvBn3d,
nniqat.ConvBnReLU3d,
nniqat.ConvReLU3d,
nniq.ConvReLU3d,
nni.ConvReLU3d,
]),
# conv functionals
'torch.nn.functional.conv1d': set([
F.conv1d,
toq.conv1d,
toq.conv1d_relu,
]),
'torch.nn.functional.conv2d': set([
F.conv2d,
toq.conv2d,
toq.conv2d_relu,
]),
'torch.nn.functional.conv3d': set([
F.conv3d,
toq.conv3d,
toq.conv3d_relu,
]),
# linear modules
'torch.nn.Linear': set([
nn.Linear,
nnq.Linear,
nni.LinearReLU,
nniq.LinearReLU,
nnqat.Linear,
nnqd.Linear,
nniqat.LinearReLU,
]),
# linear functionals
'torch.nn.functional.linear': set([
F.linear,
toq.linear,
toq.linear_relu,
]),
# LSTM
'torch.nn.LSTM': set([
nn.LSTM,
nnqd.LSTM,
]),
# add
'torch.add': set([
torch.add,
toq.add,
operator.add, # x + y
]),
# cat
'torch.cat': set([
torch.cat,
toq.cat,
]),
# mul
'torch.mul': set([
torch.mul,
toq.mul,
]),
# relu
'torch.relu': set([
F.relu,
]),
# maxpool2d
'torch.nn.MaxPool2d': set([
nn.MaxPool2d,
]),
# sigmoid
'torch.sigmoid': set([
torch.sigmoid,
]),
}
return base_name_to_sets_of_related_ops
def get_type_a_related_to_b(
base_name_to_sets_of_related_ops: Dict[str, Set[Callable]],
) -> Set[Tuple[Callable, Callable]]:
# TODO(future PR): allow customizations
# TODO(future PR): reuse existing quantization mappings
# TODO(future PR): add the rest of modules and ops here
type_a_related_to_b: Set[Tuple[Callable, Callable]] = set()
for base_name, s in base_name_to_sets_of_related_ops.items():
s_list = list(s)
# add every bidirectional pair
for idx_0 in range(0, len(s_list) - 1):
for idx_1 in range(idx_0 + 1, len(s_list)):
type_a_related_to_b.add((s_list[idx_0], s_list[idx_1]))
type_a_related_to_b.add((s_list[idx_1], s_list[idx_0]))
return type_a_related_to_b
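
Editor's note: a minimal usage sketch of the two helpers above, relying on the imports at the top of this new file; the specific membership checks are illustrative, based on the related-op sets defined in this diff.

```
base_map = get_base_name_to_sets_of_related_ops()
related = get_type_a_related_to_b(base_map)

# pairs within a related set are added in both directions
assert (nn.Conv2d, nnq.Conv2d) in related
assert (nnq.Conv2d, nn.Conv2d) in related

# ops from different base-name sets are not related
assert (nn.Conv2d, nn.Linear) not in related
```
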
NSFusionElType = Union[
Callable, # call_function or call_module type, example: F.linear or nn.Conv2d
str, # call_method name, example: "dequantize"
Tuple[str, Any], # call_method name and first argument, example: ("to", torch.float16)
]
NSFusionType = Union[
Tuple[NSFusionElType, NSFusionElType],
Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]
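
Editor's note: to make the alias concrete, two illustrative values; both correspond to patterns used later in this file, and the variable names are hypothetical.

```
# 2-element fusion, stored reversed: (last op, first op)
rev_conv_relu: NSFusionType = (nn.ReLU, nn.Conv2d)

# 4-element fp16 emulation fusion, mixing a call_method-with-arg element,
# call_function elements, and a bare call_method name
rev_fp16_linear: NSFusionType = (("to", torch.float16), F.relu, F.linear, "dequantize")
```
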
def get_reversed_fusions() -> Set[Tuple[NSFusionType, int]]:
"""
Set of potential fusions, in reverse order. The order is reversed
to match how fusion patterns are defined in quantization code.
Fusion format:
((fusion_op_0, fusion_op_1), base_op_idx)
Where base_op_idx is the idx of the op we should use to match other related
ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
of 0 represents the first op in regular (non-reverse) order, 1 represents the
second op, etc.
"""
results: Set[Tuple[NSFusionType, int]] = set([])
# Possible syntaxes:
# * single op: torch.nn.Conv2d
# * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
# For fusions, we only care about patterns composed of multiple ops.
# TODO(future PR): allow customizations from default patterns.
all_quant_patterns = get_default_quant_patterns()
default_base_op_idx = 0
for quant_pattern, _quant_handler in all_quant_patterns.items():
# this only takes patterns of multiple ops
if isinstance(quant_pattern, tuple):
results.add((quant_pattern, default_base_op_idx)) # type: ignore
# After this point, results contains values such as
# [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]
# Patterns for matching fp16 emulation are not specified in the quantization
# fusion mappings. For now, define them here.
fp16_em_base_op_idx = 1
patterns_to_add = [
# linear-relu fp16 emulation:
# fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
((("to", torch.float16), F.relu, F.linear, "dequantize"), fp16_em_base_op_idx,),
]
for p in patterns_to_add:
results.add(p)
return results
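
Editor's note: a sketch of what the returned set contains, assuming the default quant patterns include the conv2d-relu fusion (as the comment above notes).

```
fusions = get_reversed_fusions()

# default quant patterns contribute base_op_idx 0; the tuple below matches
# conv2d -> relu in forward order, stored reversed
assert ((nn.ReLU, nn.Conv2d), 0) in fusions

# the manually added fp16 emulation pattern uses base_op_idx 1 (the linear)
assert ((("to", torch.float16), F.relu, F.linear, "dequantize"), 1) in fusions
```
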
def end_node_matches_reversed_fusion(
end_node: Node,
reversed_fusion: NSFusionType,
gm: GraphModule,
) -> bool:
"""
Returns true if a pattern ending with `end_node` matches
the fusion pattern.
"""
cur_node = end_node
for fusion_idx in range(len(reversed_fusion)):
cur_fusion_el = reversed_fusion[fusion_idx]
if cur_node.op == 'call_function':
fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and \
(not isinstance(cur_fusion_el, type))
if fusion_el_is_fun:
if cur_node.target != cur_fusion_el:
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
else:
return False
elif cur_node.op == 'call_module':
fusion_el_is_mod = isinstance(cur_fusion_el, type)
if fusion_el_is_mod:
assert isinstance(cur_node.target, str)
target_mod = getattr_from_fqn(gm, cur_node.target)
if not isinstance(cur_fusion_el, type):
return False
if not isinstance(target_mod, cur_fusion_el):
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
else:
return False
elif cur_node.op == 'call_method':
fusion_el_is_meth_with_second_arg = \
isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
if fusion_el_is_meth_without_args:
if cur_node.target != cur_fusion_el:
return False
else:
assert isinstance(cur_fusion_el, tuple)
if cur_node.target != cur_fusion_el[0]:
return False
elif len(cur_node.args) < 2:
return False
elif cur_node.args[1] != cur_fusion_el[1]:
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
else:
return False
else:
return False
return True
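
Editor's note: finally, a minimal end-to-end sketch of the matcher above, assuming `end_node_matches_reversed_fusion` is importable from this module; the toy module and names are hypothetical.

```
import torch
import torch.fx
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

gm = torch.fx.symbolic_trace(M())
# the node feeding the graph output is the ReLU call, i.e. the pattern's end
end_node = [n for n in gm.graph.nodes if n.op != 'output'][-1]

# reversed order: the end op comes first
assert end_node_matches_reversed_fusion(end_node, (nn.ReLU, nn.Linear), gm)
assert not end_node_matches_reversed_fusion(end_node, (nn.ReLU, nn.Conv2d), gm)
```
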