mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61323 Before this PR, all observers and fake quants were silently removed when adding loggers with NS. This is problematic for QAT models because we need the fake quants to run in order to properly capture intermediate outputs. This PR fixes the issue by preserving the observers throughout the passes which add loggers. In detail: * for each quantization module or fusion, add additional patterns with that fusion and an observer/fake_quant at the end * remove the places in the logger model creation code which removed observers * add unit testing that QAT numerics do not change after adding loggers Test Plan: ``` python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_loggers_preserve_qat_numerics python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_shadow_loggers_preserve_qat_numerics ``` Imported from OSS Reviewed By: hx89 Differential Revision: D29600351 fbshipit-source-id: 5f25118b79eb47860c49bca882de6a8eae7a4456
186 lines
7.3 KiB
Python
186 lines
7.3 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
# Short alias for the `torch.ops.quantized` operator namespace.
toq = torch.ops.quantized
|
|
|
|
from torch.fx import GraphModule
|
|
from torch.fx.graph import Node
|
|
|
|
from .utils import getattr_from_fqn
|
|
from .ns_types import NSNodeTargetType
|
|
from torch.quantization.fx.pattern_utils import get_default_quant_patterns
|
|
from torch.quantization import (
|
|
ObserverBase,
|
|
FakeQuantizeBase,
|
|
)
|
|
|
|
from typing import Dict, Tuple, Set, Callable, Any, Union, List
|
|
|
|
|
|
def get_type_a_related_to_b(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
) -> Set[Tuple[NSNodeTargetType, NSNodeTargetType]]:
    """
    Expands groups of related ops into a flat set of related pairs.

    For every set of related ops in `base_name_to_sets_of_related_ops`,
    the result contains each ordered pair `(a, b)` where both `a` and `b`
    belong to that set. This includes both directions of every pair as
    well as the pair of each op with itself. The base names (dict keys)
    do not influence the result.
    """
    # TODO(future PR): allow customizations
    # TODO(future PR): reuse existing quantization mappings
    # TODO(future PR): add the rest of modules and ops here
    related_pairs: Set[Tuple[NSNodeTargetType, NSNodeTargetType]] = {
        (type_a, type_b)
        for related_ops in base_name_to_sets_of_related_ops.values()
        for type_a in related_ops
        for type_b in related_ops
    }
    return related_pairs
|
|
|
|
|
|
# Type of a single element of a fusion pattern. The matching code in
# `end_node_matches_reversed_fusion` interprets an element according to
# its Python type: `str` and 2-tuples are matched against call_method
# nodes, classes against call_module nodes, and any remaining value
# against the target of call_function nodes.
NSFusionElType = Union[
    Callable,  # call_function or call_module type, example: F.linear or nn.Conv2d
    str,  # call_method name, example: "dequantize"
    # call_method name and first argument, example: ("to", torch.float16);
    # the second item is compared against `args[1]` of the call_method node
    Tuple[str, Any],
]
# Type of a whole fusion pattern: a tuple of fusion elements, listed in
# reverse order (last op first).
# NOTE: only 2- and 4-element patterns are spelled out here, while
# `get_reversed_fusions` also appends patterns of other lengths (for
# example 3 and 5 elements) and silences the type checker with
# `# type: ignore[arg-type]`.
NSFusionType = Union[
    Tuple[NSFusionElType, NSFusionElType],
    Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]
|
|
|
|
def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
    """
    Builds the list of potential fusions, each written in reverse order
    (last op first) to match how fusion patterns are defined in the
    quantization code.

    Each entry of the returned list has the format:
        ((fusion_op_0, fusion_op_1, ...), base_op_idx)

    `base_op_idx` is the index of the op which should be used to match
    other related ops. It is expressed in regular (non-reverse) order:
    0 refers to the first op of the fusion in execution order, 1 to the
    second op, and so on.

    The list is assembled from the default quantization patterns followed
    by a few patterns defined by hand in this function. Every pattern is
    additionally emitted with an observer and with a fake quant as the
    last op, i.e. prepended to the reversed pattern.
    """
    # Module base classes which may follow a pattern.
    post_process_classes = (ObserverBase, FakeQuantizeBase)
    first_op_idx = 0
    fusions: List[Tuple[NSFusionType, int]] = []

    # Keys of the default quantization patterns come in two syntaxes:
    #   * single op: torch.nn.Conv2d
    #   * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
    # The pattern handlers (the dict values) are not needed here.
    # TODO(future PR): allow customizations from default patterns.
    for pattern in get_default_quant_patterns():
        pattern_ops: Tuple[Any, ...]
        if isinstance(pattern, tuple):
            # Only multi-op patterns are fusions by themselves. Single op
            # patterns get matched without caring about fusions.
            fusions.append((pattern, first_op_idx))  # type: ignore[arg-type]
            pattern_ops = pattern
        else:
            pattern_ops = (pattern,)

        # Both single op and multi-op patterns get variants which end
        # with an observer or a fake quant.
        # TODO(future PR): if needed, implement matching for a node
        #   having multiple output observers.
        for post_process_cls in post_process_classes:
            fusions.append(
                ((post_process_cls, *pattern_ops), first_op_idx))  # type: ignore[arg-type]

    # At this point, `fusions` contains values such as
    #   [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]

    # Patterns for matching fp16 emulation are not specified in the
    # quantization fusion mappings. For now, define them here.
    fp16_em_base_op_idx = 1
    manual_patterns = [
        # linear-relu fp16 emulation:
        # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
        ((("to", torch.float16), F.relu, F.linear, "dequantize"), fp16_em_base_op_idx,),
        # Conv-BN fusion (this happens outside of quantization patterns,
        # which is why it is defined separately here).
        ((nn.BatchNorm1d, nn.Conv1d), first_op_idx),
        ((nn.BatchNorm2d, nn.Conv2d), first_op_idx),
        ((nn.BatchNorm3d, nn.Conv3d), first_op_idx),
        ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), first_op_idx),
        ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), first_op_idx),
        ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), first_op_idx),
    ]
    for entry in manual_patterns:
        manual_ops, base_op_idx = entry
        fusions.append(entry)  # type: ignore[arg-type]
        for post_process_cls in post_process_classes:
            fusions.append(
                ((post_process_cls, *manual_ops), base_op_idx))  # type: ignore[arg-type]

    return fusions
|
|
|
|
|
|
def end_node_matches_reversed_fusion(
    end_node: Node,
    reversed_fusion: NSFusionType,
    gm: GraphModule,
) -> bool:
    """
    Checks whether the chain of nodes which ends at `end_node` matches
    `reversed_fusion`.

    The pattern is consumed element by element, starting at `end_node` and
    stepping to the first positional argument of the current node after
    each successful element match. Elements are compared as follows:

      * call_function node: the element must be neither a `str` nor a
        class, and must be equal to the node's target
      * call_module node: the element must be a class, and the module
        found in `gm` under the node's target must be an instance of it
      * call_method node: the element must be either a method name equal
        to the node's target, or a `(method_name, arg)` tuple where
        `method_name` equals the target and `arg` equals the node's
        second positional argument
      * any other kind of node never matches

    Note that every matched node, including the one matched by the last
    pattern element, must have a `Node` as its first positional argument,
    otherwise the pattern is considered not matched.

    Returns True if all elements of the pattern matched, False otherwise.
    """
    node = end_node
    for pattern_el in reversed_fusion:
        op_kind = node.op

        if op_kind == 'call_function':
            # strings describe methods and classes describe modules,
            # neither of them can match a function call
            if isinstance(pattern_el, (str, type)):
                return False
            if node.target != pattern_el:
                return False

        elif op_kind == 'call_module':
            if not isinstance(pattern_el, type):
                return False
            assert isinstance(node.target, str)
            module = getattr_from_fqn(gm, node.target)
            if not isinstance(module, pattern_el):
                return False

        elif op_kind == 'call_method':
            if isinstance(pattern_el, str):
                # method specified by name only
                if node.target != pattern_el:
                    return False
            elif isinstance(pattern_el, tuple) and len(pattern_el) == 2:
                # method specified by name and second positional argument
                method_name, expected_arg = pattern_el
                if node.target != method_name:
                    return False
                if len(node.args) < 2:
                    return False
                if node.args[1] != expected_arg:
                    return False
            else:
                return False

        else:
            return False

        # The current element matched, step to the node which produces the
        # first positional argument. Bail out if there is no such node.
        if len(node.args) == 0:
            return False
        producer = node.args[0]
        if not isinstance(producer, Node):
            return False
        node = producer

    return True
|