dbr quant function fusion [1/x]: record matches for functions (#71764)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71764

For DBR quant, adds the code for matching seen ops to function fusion
patterns. After we have the full DAG, we have a separate pass over the
dag and add matched fusion patterns to the seen op data structure.

This is the first PR in the stack which implements matching and
recording the match results. Future PRs in this stack will use
the match results to modify observer insertion and inference.

Test Plan:
```
python test/test_quantization.py TestQuantizeDBR.test_fusion_functions
```

Reviewed By: jerryzh168

Differential Revision: D33775098

Pulled By: vkuzo

fbshipit-source-id: 488aac902bf568d41c863ee49248990411ed9c53
This commit is contained in:
Vasiliy Kuznetsov 2022-02-07 05:55:40 -08:00 committed by Facebook GitHub Bot
parent b0d48a8e66
commit 4ad1ca1abc
7 changed files with 144 additions and 20 deletions

View File

@ -422,6 +422,25 @@ class TestQuantizeDBR(QuantizeDBRTestCase):
qconfig = torch.quantization.default_qconfig
self._test_auto_tracing(m, qconfig, (torch.randn(1, 1, 2, 2),))
def test_fusion_functions(self):
class M(torch.nn.Module):
def forward(self, x):
x = x + x
x = torch.relu(x)
return x
m = M().eval()
qconfig = torch.quantization.default_qconfig
mp = _quantize_dbr.prepare(m, {'': qconfig}, (torch.randn(1, 1, 1, 1),))
self.assertTrue(
mp._auto_quant_state.idx_to_seen_q_op_infos[0].fusion_info is not None)
self.assertTrue(
mp._auto_quant_state.idx_to_seen_q_op_infos[1].fusion_info is not None)
# TODO(future PR): use fusion results to insert observers
# TODO(future PR): use fusion results to replace function at inference
# TODO(future PR): use information about non-quantizeable ops during
# matching fusion patterns
def test_observers_not_touched_by_tracing(self):
"""
Verifies that running dynamic tracing does not change any data

View File

@ -381,6 +381,7 @@ def add_auto_observation(
if first_call:
for _, v in self.named_modules():
if hasattr(v, '_auto_quant_state'):
v._auto_quant_state.match_fusion_patterns()
v._auto_quant_state.insert_observers(v)
return output

View File

@ -0,0 +1,77 @@
from typing import Dict, Tuple, Callable, Optional
from .mappings import known_function_fusion_patterns_and_replacements
from .utils import (
FusionInfo,
SeenQOpInfo,
get_users_of_seen_q_op_info,
)
def _identity(x):
return x
def pattern_is_match(
fusion_pattern: Tuple[Callable, ...],
cur_seen_q_op_info: Optional[SeenQOpInfo],
idx_to_seen_q_op_infos: Dict[int, SeenQOpInfo],
) -> bool:
is_match = True
for el_type in fusion_pattern:
if cur_seen_q_op_info is not None and el_type == cur_seen_q_op_info.type:
next_seen_q_op_infos = get_users_of_seen_q_op_info(
idx_to_seen_q_op_infos, cur_seen_q_op_info)
if len(next_seen_q_op_infos) == 1:
cur_seen_q_op_info = next_seen_q_op_infos[0]
else:
cur_seen_q_op_info = None
continue
else:
is_match = False
break
return is_match
def match_fusion_patterns(
idx_to_seen_q_op_infos: Dict[int, SeenQOpInfo],
):
"""
Matches fusion patterns to elements of `idx_to_seen_q_op_infos`.
Modifies them inplace if matches are found.
Note:
1. The matching is local to the ops seen by a single parent module,
it does not cross module boundaries. This is for simplicity, and
there are no plans to relax this at the moment.
2. The matching only supports linear patterns of ops where all of
of the arguments needed to execute the fusion are passed to the first
op in the sequence. This is for simplicity, and can be relaxed
in a future PR if there is a need.
3. Currently the matching does not look at non quantizeable ops,
this will be fixed in the next PR.
"""
# Walk the subgraphs and find the function fusions. For now, this is
# brute forced for simplicity, can be optimized later if necessary.
for idx, seen_q_op_info in idx_to_seen_q_op_infos.items():
for fusion_pattern, replacement in \
known_function_fusion_patterns_and_replacements.items():
is_match = pattern_is_match(
fusion_pattern, seen_q_op_info, idx_to_seen_q_op_infos)
if not is_match:
continue
cur_seen_q_op_info = seen_q_op_info
for idx in range(len(fusion_pattern)):
if idx > 0:
users = get_users_of_seen_q_op_info(
idx_to_seen_q_op_infos, cur_seen_q_op_info)
cur_seen_q_op_info = users[0]
is_first_element = idx == 0
is_last_element = idx == len(fusion_pattern) - 1
replacement_type = replacement if is_first_element \
else _identity
fusion_info = FusionInfo(
fusion_pattern, replacement_type, is_first_element,
is_last_element)
cur_seen_q_op_info.fusion_info = fusion_info
break

View File

@ -2,6 +2,8 @@ from typing import List
import torch
from .function_fusion import pattern_is_match
from .utils import (
get_users_of_seen_q_op_info,
)
@ -19,8 +21,6 @@ def get_module_fusion_fqns(
Walks the subgraphs and determines which modules should be
fused.
TODO: test coverage
Output: a list of FQNs of modules which should be fused.
"""
results = []
@ -32,25 +32,21 @@ def get_module_fusion_fqns(
# Walk the subgraphs and record the FQNs of all known module fusions.
# For now, this is brute forced for simplicity, can be optimized later if
# necessary.
# TODO(future PR): if a pattern is matched, add it to "seen" items
# and do not use it in future matching.
for idx, seen_q_op_info in qstate.idx_to_seen_q_op_infos.items():
for fusion_pattern in known_module_fusion_patterns:
cur_fqns = []
cur_seen_q_op_info = seen_q_op_info
is_match = True
for mod_type in fusion_pattern:
if cur_seen_q_op_info is not None and mod_type == cur_seen_q_op_info.type:
cur_fqns.append(cur_seen_q_op_info.fqn)
next_seen_q_op_infos = get_users_of_seen_q_op_info(
qstate.idx_to_seen_q_op_infos, cur_seen_q_op_info)
if len(next_seen_q_op_infos) == 1:
cur_seen_q_op_info = next_seen_q_op_infos[0]
else:
cur_seen_q_op_info = None
continue
else:
is_match = False
break
is_match = pattern_is_match(
fusion_pattern, seen_q_op_info, qstate.idx_to_seen_q_op_infos)
if is_match:
cur_fqns = [seen_q_op_info.fqn]
cur_seen_q_op_info = seen_q_op_info
for _element in fusion_pattern[:-1]:
users = get_users_of_seen_q_op_info(
qstate.idx_to_seen_q_op_infos, cur_seen_q_op_info)
cur_seen_q_op_info = users[0]
cur_fqns.append(cur_seen_q_op_info.fqn)
# we check for existence to ensure the final fusion list
# is deduplicated, in case the same op is called multiple
# times in a single forward

View File

@ -101,6 +101,11 @@ known_module_fusion_patterns = [
(torch.nn.Conv2d, torch.nn.BatchNorm2d),
]
# TODO(future): reuse global mapping
known_function_fusion_patterns_and_replacements = {
(torch.Tensor.add, torch.relu): toq.add_relu,
}
binary_related_ops = (
(torch.add, torch.Tensor.add),
(torch.add, torch.Tensor.add_),

View File

@ -36,6 +36,10 @@ from .utils import (
OpQuantizeabilityType,
)
from .function_fusion import (
match_fusion_patterns,
)
OpConvertInfo = Tuple[
# quantized equivalent of original op (None means keep original)
Optional[Callable],
@ -176,7 +180,7 @@ class AutoQuantizationState(torch.nn.Module):
s += f" {k}: {v}\n"
s += "}\n"
else:
s += "(idx_to_packed_weight_name): {}"
s += "(idx_to_packed_weight_name): {}\n"
if len(self.tensor_id_to_scale_zp):
s += "(tensor_id_to_scale_zp): {\n"
for k, v in self.tensor_id_to_scale_zp.items(): # type: ignore[assignment]
@ -758,7 +762,7 @@ class AutoQuantizationState(torch.nn.Module):
self.idx, op_type, op_type_is_module, fqn, arg_tensor_infos, [],
packable_tensor_idx_to_name, packable_nontensor_idx_to_arg,
packable_tensor_kwarg_name_to_name,
op_packing_only_uses_module_attributes, qconfig)
op_packing_only_uses_module_attributes, qconfig, None)
return args, kwargs
@ -834,6 +838,9 @@ class AutoQuantizationState(torch.nn.Module):
if isinstance(element, torch.Tensor):
_add_output_qtensor_info(element, dtype_to_use)
def match_fusion_patterns(self):
match_fusion_patterns(self.idx_to_seen_q_op_infos)
def _maybe_insert_input_observers(self, seen_q_op_info: SeenQOpInfo):
func_output_dtype_type = get_func_output_dtype_type(seen_q_op_info)
input_observed_arg_idxs = get_input_observed_arg_idxs(

View File

@ -49,6 +49,21 @@ class QTensorInfo:
inf_dtype: torch.dtype # dtype at inference
@dataclasses.dataclass
class FusionInfo:
# linear matched pattern, example: [torch.add, torch.relu]
pattern: Tuple[Callable, ...]
# what the current element should be replaced with during execution
# example: toq.add_relu (for torch.add -> torch.relu)
replacement_type_this_element: Callable
# true if the current element is the first element of the pattern,
# for example true for torch.add in (torch.add -> torch.relu)
is_first_element: bool
# true if the current element is the last element of the pattern,
# for example true for torch.relu in (torch.add -> torch.relu)
is_last_element: bool
@dataclasses.dataclass
class SeenQOpInfo:
idx: int
@ -84,6 +99,8 @@ class SeenQOpInfo:
op_packing_only_uses_module_attributes: bool
# QConfig for the op, can be None
qconfig: QConfigAny
# fusion_info for the op, is None if no fusion is found
fusion_info: Optional[FusionInfo]
def __repr__(self) -> str:
s = f"(type): {self.type}\n"
@ -96,6 +113,8 @@ class SeenQOpInfo:
s += f"\n (packable_nontensor_idx_to_arg): {self.packable_nontensor_idx_to_arg}"
if len(self.packable_tensor_kwarg_name_to_name):
s += f"\n (packable_tensor_kwarg_name_to_name): {self.packable_tensor_kwarg_name_to_name}"
if self.fusion_info:
s += f"\n (fusion_info): {self.fusion_info}"
return s