mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
dbr quant function fusion [1/x]: record matches for functions (#71764)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71764 For DBR quant, adds the code for matching seen ops to function fusion patterns. After we have the full DAG, we have a separate pass over the dag and add matched fusion patterns to the seen op data structure. This is the first PR in the stack which implements matching and recording the match results. Future PRs in this stack will use the match results to modify observer insertion and inference. Test Plan: ``` python test/test_quantization.py TestQuantizeDBR.test_fusion_functions ``` Reviewed By: jerryzh168 Differential Revision: D33775098 Pulled By: vkuzo fbshipit-source-id: 488aac902bf568d41c863ee49248990411ed9c53
This commit is contained in:
parent
b0d48a8e66
commit
4ad1ca1abc
|
|
@ -422,6 +422,25 @@ class TestQuantizeDBR(QuantizeDBRTestCase):
|
|||
qconfig = torch.quantization.default_qconfig
|
||||
self._test_auto_tracing(m, qconfig, (torch.randn(1, 1, 2, 2),))
|
||||
|
||||
def test_fusion_functions(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x = x + x
|
||||
x = torch.relu(x)
|
||||
return x
|
||||
|
||||
m = M().eval()
|
||||
qconfig = torch.quantization.default_qconfig
|
||||
mp = _quantize_dbr.prepare(m, {'': qconfig}, (torch.randn(1, 1, 1, 1),))
|
||||
self.assertTrue(
|
||||
mp._auto_quant_state.idx_to_seen_q_op_infos[0].fusion_info is not None)
|
||||
self.assertTrue(
|
||||
mp._auto_quant_state.idx_to_seen_q_op_infos[1].fusion_info is not None)
|
||||
# TODO(future PR): use fusion results to insert observers
|
||||
# TODO(future PR): use fusion results to replace function at inference
|
||||
# TODO(future PR): use information about non-quantizeable ops during
|
||||
# matching fusion patterns
|
||||
|
||||
def test_observers_not_touched_by_tracing(self):
|
||||
"""
|
||||
Verifies that running dynamic tracing does not change any data
|
||||
|
|
|
|||
|
|
@ -381,6 +381,7 @@ def add_auto_observation(
|
|||
if first_call:
|
||||
for _, v in self.named_modules():
|
||||
if hasattr(v, '_auto_quant_state'):
|
||||
v._auto_quant_state.match_fusion_patterns()
|
||||
v._auto_quant_state.insert_observers(v)
|
||||
|
||||
return output
|
||||
|
|
|
|||
77
torch/ao/quantization/_dbr/function_fusion.py
Normal file
77
torch/ao/quantization/_dbr/function_fusion.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
from typing import Dict, Tuple, Callable, Optional
|
||||
|
||||
from .mappings import known_function_fusion_patterns_and_replacements
|
||||
from .utils import (
|
||||
FusionInfo,
|
||||
SeenQOpInfo,
|
||||
get_users_of_seen_q_op_info,
|
||||
)
|
||||
|
||||
def _identity(x):
|
||||
return x
|
||||
|
||||
def pattern_is_match(
|
||||
fusion_pattern: Tuple[Callable, ...],
|
||||
cur_seen_q_op_info: Optional[SeenQOpInfo],
|
||||
idx_to_seen_q_op_infos: Dict[int, SeenQOpInfo],
|
||||
) -> bool:
|
||||
is_match = True
|
||||
for el_type in fusion_pattern:
|
||||
if cur_seen_q_op_info is not None and el_type == cur_seen_q_op_info.type:
|
||||
next_seen_q_op_infos = get_users_of_seen_q_op_info(
|
||||
idx_to_seen_q_op_infos, cur_seen_q_op_info)
|
||||
if len(next_seen_q_op_infos) == 1:
|
||||
cur_seen_q_op_info = next_seen_q_op_infos[0]
|
||||
else:
|
||||
cur_seen_q_op_info = None
|
||||
continue
|
||||
else:
|
||||
is_match = False
|
||||
break
|
||||
return is_match
|
||||
|
||||
def match_fusion_patterns(
|
||||
idx_to_seen_q_op_infos: Dict[int, SeenQOpInfo],
|
||||
):
|
||||
"""
|
||||
Matches fusion patterns to elements of `idx_to_seen_q_op_infos`.
|
||||
Modifies them inplace if matches are found.
|
||||
|
||||
Note:
|
||||
1. The matching is local to the ops seen by a single parent module,
|
||||
it does not cross module boundaries. This is for simplicity, and
|
||||
there are no plans to relax this at the moment.
|
||||
2. The matching only supports linear patterns of ops where all of
|
||||
of the arguments needed to execute the fusion are passed to the first
|
||||
op in the sequence. This is for simplicity, and can be relaxed
|
||||
in a future PR if there is a need.
|
||||
3. Currently the matching does not look at non quantizeable ops,
|
||||
this will be fixed in the next PR.
|
||||
"""
|
||||
|
||||
# Walk the subgraphs and find the function fusions. For now, this is
|
||||
# brute forced for simplicity, can be optimized later if necessary.
|
||||
for idx, seen_q_op_info in idx_to_seen_q_op_infos.items():
|
||||
for fusion_pattern, replacement in \
|
||||
known_function_fusion_patterns_and_replacements.items():
|
||||
is_match = pattern_is_match(
|
||||
fusion_pattern, seen_q_op_info, idx_to_seen_q_op_infos)
|
||||
if not is_match:
|
||||
continue
|
||||
|
||||
cur_seen_q_op_info = seen_q_op_info
|
||||
for idx in range(len(fusion_pattern)):
|
||||
if idx > 0:
|
||||
users = get_users_of_seen_q_op_info(
|
||||
idx_to_seen_q_op_infos, cur_seen_q_op_info)
|
||||
cur_seen_q_op_info = users[0]
|
||||
|
||||
is_first_element = idx == 0
|
||||
is_last_element = idx == len(fusion_pattern) - 1
|
||||
replacement_type = replacement if is_first_element \
|
||||
else _identity
|
||||
fusion_info = FusionInfo(
|
||||
fusion_pattern, replacement_type, is_first_element,
|
||||
is_last_element)
|
||||
cur_seen_q_op_info.fusion_info = fusion_info
|
||||
break
|
||||
|
|
@ -2,6 +2,8 @@ from typing import List
|
|||
|
||||
import torch
|
||||
|
||||
from .function_fusion import pattern_is_match
|
||||
|
||||
from .utils import (
|
||||
get_users_of_seen_q_op_info,
|
||||
)
|
||||
|
|
@ -19,8 +21,6 @@ def get_module_fusion_fqns(
|
|||
Walks the subgraphs and determines which modules should be
|
||||
fused.
|
||||
|
||||
TODO: test coverage
|
||||
|
||||
Output: a list of FQNs of modules which should be fused.
|
||||
"""
|
||||
results = []
|
||||
|
|
@ -32,25 +32,21 @@ def get_module_fusion_fqns(
|
|||
# Walk the subgraphs and record the FQNs of all known module fusions.
|
||||
# For now, this is brute forced for simplicity, can be optimized later if
|
||||
# necessary.
|
||||
# TODO(future PR): if a pattern is matched, add it to "seen" items
|
||||
# and do not use it in future matching.
|
||||
for idx, seen_q_op_info in qstate.idx_to_seen_q_op_infos.items():
|
||||
for fusion_pattern in known_module_fusion_patterns:
|
||||
cur_fqns = []
|
||||
cur_seen_q_op_info = seen_q_op_info
|
||||
is_match = True
|
||||
for mod_type in fusion_pattern:
|
||||
if cur_seen_q_op_info is not None and mod_type == cur_seen_q_op_info.type:
|
||||
cur_fqns.append(cur_seen_q_op_info.fqn)
|
||||
next_seen_q_op_infos = get_users_of_seen_q_op_info(
|
||||
qstate.idx_to_seen_q_op_infos, cur_seen_q_op_info)
|
||||
if len(next_seen_q_op_infos) == 1:
|
||||
cur_seen_q_op_info = next_seen_q_op_infos[0]
|
||||
else:
|
||||
cur_seen_q_op_info = None
|
||||
continue
|
||||
else:
|
||||
is_match = False
|
||||
break
|
||||
is_match = pattern_is_match(
|
||||
fusion_pattern, seen_q_op_info, qstate.idx_to_seen_q_op_infos)
|
||||
if is_match:
|
||||
cur_fqns = [seen_q_op_info.fqn]
|
||||
cur_seen_q_op_info = seen_q_op_info
|
||||
for _element in fusion_pattern[:-1]:
|
||||
users = get_users_of_seen_q_op_info(
|
||||
qstate.idx_to_seen_q_op_infos, cur_seen_q_op_info)
|
||||
cur_seen_q_op_info = users[0]
|
||||
cur_fqns.append(cur_seen_q_op_info.fqn)
|
||||
|
||||
# we check for existence to ensure the final fusion list
|
||||
# is deduplicated, in case the same op is called multiple
|
||||
# times in a single forward
|
||||
|
|
|
|||
|
|
@ -101,6 +101,11 @@ known_module_fusion_patterns = [
|
|||
(torch.nn.Conv2d, torch.nn.BatchNorm2d),
|
||||
]
|
||||
|
||||
# TODO(future): reuse global mapping
|
||||
known_function_fusion_patterns_and_replacements = {
|
||||
(torch.Tensor.add, torch.relu): toq.add_relu,
|
||||
}
|
||||
|
||||
binary_related_ops = (
|
||||
(torch.add, torch.Tensor.add),
|
||||
(torch.add, torch.Tensor.add_),
|
||||
|
|
|
|||
|
|
@ -36,6 +36,10 @@ from .utils import (
|
|||
OpQuantizeabilityType,
|
||||
)
|
||||
|
||||
from .function_fusion import (
|
||||
match_fusion_patterns,
|
||||
)
|
||||
|
||||
OpConvertInfo = Tuple[
|
||||
# quantized equivalent of original op (None means keep original)
|
||||
Optional[Callable],
|
||||
|
|
@ -176,7 +180,7 @@ class AutoQuantizationState(torch.nn.Module):
|
|||
s += f" {k}: {v}\n"
|
||||
s += "}\n"
|
||||
else:
|
||||
s += "(idx_to_packed_weight_name): {}"
|
||||
s += "(idx_to_packed_weight_name): {}\n"
|
||||
if len(self.tensor_id_to_scale_zp):
|
||||
s += "(tensor_id_to_scale_zp): {\n"
|
||||
for k, v in self.tensor_id_to_scale_zp.items(): # type: ignore[assignment]
|
||||
|
|
@ -758,7 +762,7 @@ class AutoQuantizationState(torch.nn.Module):
|
|||
self.idx, op_type, op_type_is_module, fqn, arg_tensor_infos, [],
|
||||
packable_tensor_idx_to_name, packable_nontensor_idx_to_arg,
|
||||
packable_tensor_kwarg_name_to_name,
|
||||
op_packing_only_uses_module_attributes, qconfig)
|
||||
op_packing_only_uses_module_attributes, qconfig, None)
|
||||
|
||||
return args, kwargs
|
||||
|
||||
|
|
@ -834,6 +838,9 @@ class AutoQuantizationState(torch.nn.Module):
|
|||
if isinstance(element, torch.Tensor):
|
||||
_add_output_qtensor_info(element, dtype_to_use)
|
||||
|
||||
def match_fusion_patterns(self):
|
||||
match_fusion_patterns(self.idx_to_seen_q_op_infos)
|
||||
|
||||
def _maybe_insert_input_observers(self, seen_q_op_info: SeenQOpInfo):
|
||||
func_output_dtype_type = get_func_output_dtype_type(seen_q_op_info)
|
||||
input_observed_arg_idxs = get_input_observed_arg_idxs(
|
||||
|
|
|
|||
|
|
@ -49,6 +49,21 @@ class QTensorInfo:
|
|||
inf_dtype: torch.dtype # dtype at inference
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FusionInfo:
|
||||
# linear matched pattern, example: [torch.add, torch.relu]
|
||||
pattern: Tuple[Callable, ...]
|
||||
# what the current element should be replaced with during execution
|
||||
# example: toq.add_relu (for torch.add -> torch.relu)
|
||||
replacement_type_this_element: Callable
|
||||
# true if the current element is the first element of the pattern,
|
||||
# for example true for torch.add in (torch.add -> torch.relu)
|
||||
is_first_element: bool
|
||||
# true if the current element is the last element of the pattern,
|
||||
# for example true for torch.relu in (torch.add -> torch.relu)
|
||||
is_last_element: bool
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SeenQOpInfo:
|
||||
idx: int
|
||||
|
|
@ -84,6 +99,8 @@ class SeenQOpInfo:
|
|||
op_packing_only_uses_module_attributes: bool
|
||||
# QConfig for the op, can be None
|
||||
qconfig: QConfigAny
|
||||
# fusion_info for the op, is None if no fusion is found
|
||||
fusion_info: Optional[FusionInfo]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
s = f"(type): {self.type}\n"
|
||||
|
|
@ -96,6 +113,8 @@ class SeenQOpInfo:
|
|||
s += f"\n (packable_nontensor_idx_to_arg): {self.packable_nontensor_idx_to_arg}"
|
||||
if len(self.packable_tensor_kwarg_name_to_name):
|
||||
s += f"\n (packable_tensor_kwarg_name_to_name): {self.packable_tensor_kwarg_name_to_name}"
|
||||
if self.fusion_info:
|
||||
s += f"\n (fusion_info): {self.fusion_info}"
|
||||
return s
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user