pytorch/torch/quantization/ns/pattern_utils.py
Vasiliy Kuznetsov 4acd14da02 ns for fx: preserve observers and fake_quants through passes (#61323)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61323

Before this PR, all observers and fake quants were silently removed
when adding loggers with NS. This is problematic for QAT models because
we need the fake quants to run in order to properly capture intermediate
outputs.

This PR fixes the issue by preserving the observers and fake quants throughout
the passes which add loggers. In detail:
* for each quantization module or fusion, add additional patterns with that fusion and an observer/fake_quant at the end
* remove the places in the logger model creation code which removed observers
* add unit testing that QAT numerics do not change after adding loggers

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_loggers_preserve_qat_numerics
python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_shadow_loggers_preserve_qat_numerics
```

Imported from OSS

Reviewed By: hx89

Differential Revision: D29600351

fbshipit-source-id: 5f25118b79eb47860c49bca882de6a8eae7a4456
2021-07-17 20:53:33 -07:00

186 lines
7.3 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
toq = torch.ops.quantized
from torch.fx import GraphModule
from torch.fx.graph import Node
from .utils import getattr_from_fqn
from .ns_types import NSNodeTargetType
from torch.quantization.fx.pattern_utils import get_default_quant_patterns
from torch.quantization import (
ObserverBase,
FakeQuantizeBase,
)
from typing import Dict, Tuple, Set, Callable, Any, Union, List
def get_type_a_related_to_b(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
) -> Set[Tuple[NSNodeTargetType, NSNodeTargetType]]:
    """
    Flattens groups of related ops into a set of ordered (a, b) pairs.

    For every group (dict value) in `base_name_to_sets_of_related_ops`, every
    ordered pair of ops drawn from that group is included: both directions,
    and each op paired with itself. The group names (dict keys) are not used.
    Two ops end up related only if they appear together in at least one group.
    """
    # TODO(future PR): allow customizations
    # TODO(future PR): reuse existing quantization mappings
    # TODO(future PR): add the rest of modules and ops here
    return {
        (op_a, op_b)
        for related_ops in base_name_to_sets_of_related_ops.values()
        for op_a in related_ops
        for op_b in related_ops
    }
# One element of a fusion pattern. Which form is expected depends on the op
# of the node it is matched against (see end_node_matches_reversed_fusion).
NSFusionElType = Union[
    Callable,  # call_function or call_module type, example: F.linear or nn.Conv2d
    str,  # call_method name, example: "dequantize"
    # call_method name plus the first argument after `self`, i.e. the value
    # compared against node.args[1], example: ("to", torch.float16)
    Tuple[str, Any],
]
# A fusion pattern, stored in reverse order (last op of the fusion first).
# NOTE(review): get_reversed_fusions also builds 3-element patterns (e.g.
# (nn.ReLU, nn.BatchNorm2d, nn.Conv2d)) and 5-element patterns (an
# observer/fake_quant class prepended to the 4-element fp16 pattern), which
# this Union does not cover; those call sites rely on
# `# type: ignore[arg-type]`. Consider widening if the ignores are revisited.
NSFusionType = Union[
    Tuple[NSFusionElType, NSFusionElType],
    Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]
def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
    """
    List of potential fusions, each in reverse order. The order is reversed
    to match how fusion patterns are defined in quantization code.

    Fusion format:
      ((fusion_op_0, fusion_op_1), base_op_idx)

    Where base_op_idx is the idx of the op we should use to match other related
    ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
    of 0 represents the first op in regular (non-reverse) order, 1 represents
    the second op, etc.

    Besides the plain multi-op patterns, for every pattern this also returns
    one variant per class in (ObserverBase, FakeQuantizeBase), with that class
    at the end of the pattern in regular order (index 0 of the reversed
    tuple). This lets graphs that still contain observers / fake quants
    (for example QAT models) be matched. Single-op default quantization
    patterns are returned only in these observer / fake_quant variants,
    never on their own.
    """
    results: List[Tuple[NSFusionType, int]] = []

    # Classes that may terminate a pattern when observers / fake quants are
    # preserved in the graph. Shared by both loops below so the two kinds of
    # patterns always get the same variants, in the same order.
    # TODO(future PR): if needed, implement matching for a node
    # having multiple output observers.
    output_obs_classes = (ObserverBase, FakeQuantizeBase)

    # Possible syntaxes:
    # * single op: torch.nn.Conv2d
    # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
    # For fusions, we only care about patterns composed of multiple ops.
    # TODO(future PR): allow customizations from default patterns.
    all_quant_patterns = get_default_quant_patterns()

    default_base_op_idx = 0
    for quant_pattern, _quant_handler in all_quant_patterns.items():
        # Only patterns of multiple ops are fusions, ignore
        # patterns which contain a single op (they get matched
        # without caring about fusions).
        if isinstance(quant_pattern, tuple):
            results.append((quant_pattern, default_base_op_idx))  # type: ignore[arg-type]

        # For each pattern (single-op ones included), add additional
        # patterns with an observer / fake_quant at the end. Prepending to
        # the reversed tuple does not shift base_op_idx, since that index
        # counts from the first op in regular order.
        for cls in output_obs_classes:
            if isinstance(quant_pattern, tuple):
                new_pattern = (cls, *quant_pattern)
            else:
                new_pattern = (cls, quant_pattern)
            results.append((new_pattern, default_base_op_idx))  # type: ignore[arg-type]

    # After this point, results contains values such as
    # [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]

    # Patterns for matching fp16 emulation are not specified in the quantization
    # fusion mappings. For now, define them here.
    # Index 1 in regular order ("dequantize", F.linear, F.relu, ("to", fp16))
    # is F.linear.
    fp16_em_base_op_idx = 1
    patterns_to_add = [
        # linear-relu fp16 emulation:
        # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
        ((("to", torch.float16), F.relu, F.linear, "dequantize"), fp16_em_base_op_idx),
        # Conv-BN fusion (this happens outside of quantization patterns,
        # which is why it is defined separately here).
        ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
    ]
    for p in patterns_to_add:
        results.append(p)  # type: ignore[arg-type]
        # Same observer / fake_quant variants as for the default patterns.
        for cls in output_obs_classes:
            results.append(((cls, *p[0]), p[1]))  # type: ignore[arg-type]

    return results
def end_node_matches_reversed_fusion(
    end_node: Node,
    reversed_fusion: NSFusionType,
    gm: GraphModule,
) -> bool:
    """
    Checks whether the chain of nodes ending at `end_node` matches a fusion.

    Walking backwards from `end_node`, each element of `reversed_fusion` is
    compared against the current node, after which the walk steps to that
    node's first positional argument (`args[0]`). How an element is compared
    depends on the node's op:

    * ``call_function``: the element must be neither a ``str`` nor a ``type``
      and must equal ``node.target``.
    * ``call_module``: the element must be a ``type``, and the object returned
      by ``getattr_from_fqn(gm, node.target)`` must be an instance of it.
    * ``call_method``: the element is either a method name (``str``) equal to
      ``node.target``, or a ``(method_name, arg)`` 2-tuple where
      ``node.target`` must equal ``method_name`` and ``node.args[1]`` must
      equal ``arg``.

    Any other node op is a mismatch. Every matched node, including the last
    one, must have a ``Node`` as ``args[0]``; otherwise the result is False.

    Returns True only if all elements matched. An empty pattern trivially
    matches.
    """
    node = end_node
    for fusion_el in reversed_fusion:
        if node.op == 'call_function':
            # Plain functions are anything that is not a method name (str)
            # and not a module class (type).
            if isinstance(fusion_el, (str, type)) or node.target != fusion_el:
                return False
        elif node.op == 'call_module':
            if not isinstance(fusion_el, type):
                return False
            assert isinstance(node.target, str)
            if not isinstance(getattr_from_fqn(gm, node.target), fusion_el):
                return False
        elif node.op == 'call_method':
            if isinstance(fusion_el, str):
                if node.target != fusion_el:
                    return False
            elif isinstance(fusion_el, tuple) and len(fusion_el) == 2:
                # args[0] of a call_method node is `self`, so the extra
                # argument to check lives at args[1].
                method_name, expected_arg = fusion_el
                if (
                    node.target != method_name
                    or len(node.args) < 2
                    or node.args[1] != expected_arg
                ):
                    return False
            else:
                return False
        else:
            return False

        # Continue the walk through the first positional input, which must
        # itself be a graph node.
        first_arg = node.args[0] if len(node.args) > 0 else None
        if not isinstance(first_arg, Node):
            return False
        node = first_arg

    return True