pytorch/torch/ao/ns/fx/pattern_utils.py
Andrew Or ee9335a608 [Quant][fx] Define native backend_config_dict for linear and conv (#74636)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74636

This commit changes how quantization patterns for linear
and conv are set up in prepare. Previously, these were set up
through ConvReluQuantizeHandler and LinearReLUQuantizeHandler.
After this commit, they are set up through the
corresponding entries in the native backend_config_dict,
rendering the above quantize handlers unnecessary.
In future commits, we will do the same for the remaining ops.
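
For illustration (a rough sketch; the exact backend_config_dict schema is not
reproduced here), downstream consumers such as torch/ao/ns/fx/pattern_utils.py
now read the fused patterns from the native config via
get_native_quant_patterns(), e.g.:

    from torch.ao.quantization.fx.backend_config.utils import get_native_quant_patterns

    # maps (reversed) patterns such as (nn.ReLU, nn.Conv2d) to quantize handler classes
    all_quant_patterns = get_native_quant_patterns()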

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: jerryzh168, ngimel

Differential Revision: D35225680

fbshipit-source-id: 4a79f63a11fce46701eb17aaf3619c1e827d72a4
(cherry picked from commit 475f599821cd32d3ba71ba086885ecdc4cbee755)
2022-04-04 14:07:15 +00:00


import torch
import torch.nn as nn
import torch.nn.functional as F
toq = torch.ops.quantized
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.quantization.utils import getattr_from_fqn
from .ns_types import NSNodeTargetType
from torch.ao.quantization.fx.backend_config.utils import get_native_quant_patterns
from torch.ao.quantization import (
    ObserverBase,
    FakeQuantizeBase,
)

from typing import Dict, Tuple, Set, Callable, Any, Union, List


def get_type_a_related_to_b(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
) -> Set[Tuple[NSNodeTargetType, NSNodeTargetType]]:
    # TODO(future PR): allow customizations
    # TODO(future PR): reuse existing quantization mappings
    # TODO(future PR): add the rest of modules and ops here
    type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]] = set()

    for base_name, s in base_name_to_sets_of_related_ops.items():
        s_list = list(s)
        # add every bidirectional pair
        for idx_0 in range(0, len(s_list)):
            for idx_1 in range(idx_0, len(s_list)):
                type_a_related_to_b.add((s_list[idx_0], s_list[idx_1]))
                type_a_related_to_b.add((s_list[idx_1], s_list[idx_0]))

    return type_a_related_to_b
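
# Illustrative example (hypothetical input, not from the original file): given
#     {"conv": {nn.Conv2d, toq.conv2d}}
# the function returns every ordered pair within each set, including self-pairs:
#     {(nn.Conv2d, nn.Conv2d), (nn.Conv2d, toq.conv2d),
#      (toq.conv2d, nn.Conv2d), (toq.conv2d, toq.conv2d)}
# so the fp32 and quantized variants are treated as related in either direction.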


NSFusionElType = Union[
    Callable,  # call_function or call_module type, example: F.linear or nn.Conv2d
    str,  # call_method name, example: "dequantize"
    Tuple[str, Any],  # call_method name and first argument, example: ("to", torch.float16)
]

NSFusionType = Union[
    Tuple[NSFusionElType, NSFusionElType],
    Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]
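
# Examples of values of this type, both used below (reversed, i.e. last op first):
#   (nn.ReLU, nn.Conv2d)                                     # conv2d -> relu
#   (("to", torch.float16), F.relu, F.linear, "dequantize")  # dequantize -> linear -> relu -> to(fp16)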


def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
    """
    Set of potential fusions, in reverse order. The order is reversed
    to match how fusion patterns are defined in quantization code.

    Fusion format:
    ((fusion_op_0, fusion_op_1), base_op_idx)

    Where base_op_idx is the idx of the op we should use to match other related
    ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
    of 0 represents the first op in regular (non-reverse) order, 1 represents the
    second op, etc.
    """
    results: List[Tuple[NSFusionType, int]] = []

    # Possible syntaxes:
    # * single op: torch.nn.Conv2d
    # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
    # For fusions, we only care about patterns composed of multiple ops.
    # TODO(future PR): allow customizations from default patterns.
    all_quant_patterns = get_native_quant_patterns()

    default_base_op_idx = 0
    for quant_pattern, _quant_handler in all_quant_patterns.items():
        # Only patterns of multiple ops are fusions, ignore
        # patterns which contain a single op (they get matched
        # without caring about fusions).
        if isinstance(quant_pattern, tuple):
            results.append((quant_pattern, default_base_op_idx))  # type: ignore[arg-type]

        # For each pattern, add additional patterns with observers and
        # fake quants at the end.
        # TODO(future PR): if needed, implement matching for a node
        # having multiple output observers.
        for cls in (ObserverBase, FakeQuantizeBase):
            if isinstance(quant_pattern, tuple):
                new_pattern = (cls, *quant_pattern)
            else:
                new_pattern = (cls, quant_pattern)
            results.append((new_pattern, default_base_op_idx))  # type: ignore[arg-type]

    # After this point, results contains values such as
    # [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]

    # Patterns for matching fp16 emulation are not specified in the quantization
    # fusion mappings. For now, define them here.
    fp16_em_base_op_idx = 1
    patterns_to_add = [
        # linear-relu fp16 emulation:
        # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
        ((("to", torch.float16), F.relu, F.linear, "dequantize"), fp16_em_base_op_idx,),
        # Conv-BN fusion (this happens outside of quantization patterns,
        # which is why it is defined separately here).
        ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
    ]
    for p in patterns_to_add:
        results.append(p)  # type: ignore[arg-type]
        results.append(((ObserverBase, *p[0]), p[1]))  # type: ignore[arg-type]
        results.append(((FakeQuantizeBase, *p[0]), p[1]))  # type: ignore[arg-type]

    return results


def end_node_matches_reversed_fusion(
    end_node: Node,
    reversed_fusion: NSFusionType,
    gm: GraphModule,
    seen_nodes: Set[Node],
) -> bool:
    """
    Returns true if a pattern ending with `end_node` matches
    the fusion pattern.
    """
    cur_node = end_node
    for fusion_idx in range(len(reversed_fusion)):
        # each node can only belong to one matched pattern
        if cur_node in seen_nodes:
            return False

        cur_fusion_el = reversed_fusion[fusion_idx]

        if cur_node.op == 'call_function':
            fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and \
                (not isinstance(cur_fusion_el, type))
            if fusion_el_is_fun:
                if cur_node.target != cur_fusion_el:
                    return False
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False

        elif cur_node.op == 'call_module':
            fusion_el_is_mod = isinstance(cur_fusion_el, type)
            if fusion_el_is_mod:
                assert isinstance(cur_node.target, str)
                target_mod = getattr_from_fqn(gm, cur_node.target)
                if not isinstance(cur_fusion_el, type):
                    return False
                if not isinstance(target_mod, cur_fusion_el):
                    return False
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False

        elif cur_node.op == 'call_method':
            fusion_el_is_meth_with_second_arg = \
                isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
            fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
            if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
                if fusion_el_is_meth_without_args:
                    if cur_node.target != cur_fusion_el:
                        return False
                else:
                    assert isinstance(cur_fusion_el, tuple)
                    if cur_node.target != cur_fusion_el[0]:
                        return False
                    elif len(cur_node.args) < 2:
                        return False
                    elif cur_node.args[1] != cur_fusion_el[1]:
                        return False

                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False
        else:
            return False

    return True
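
# Hedged usage sketch (illustrative, not the exact caller): a matcher can walk
# a GraphModule's nodes and, for each candidate end node, try every reversed
# fusion returned above. A real caller would also record all nodes of a matched
# pattern in seen_nodes so that each node belongs to at most one match.
#
#     def find_fusion_matches(gm: GraphModule) -> List[Tuple[Node, NSFusionType]]:
#         seen_nodes: Set[Node] = set()
#         matches = []
#         for node in gm.graph.nodes:
#             for reversed_fusion, _base_op_idx in get_reversed_fusions():
#                 if end_node_matches_reversed_fusion(node, reversed_fusion, gm, seen_nodes):
#                     matches.append((node, reversed_fusion))
#                     break
#         return matches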