dbr quant function fusion [2/x]: use fusion for observation and inference (#71781)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71781

The previous PR added information about fusions found in the subgraphs.

This PR uses that information for:
1. inserting observers at the end of fusions and not in the middle
2. during inference, replacing the original ops with the fused op. This is
implemented by replacing the base op with the fused op and replacing every
other op in the fusion with an identity function, as sketched below.
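
Concretely, for the add-relu fusion exercised by the test below, the scheme works roughly like this (a minimal sketch, not the actual DBR implementation; `torch.ops.quantized.add_relu` is the fused op the test checks for):

```
import torch

def _identity(x):
    return x

# during calibration, the pattern is seen as two separate ops
def pattern(x, y):
    z = torch.add(x, y)   # base op of the fusion
    return torch.relu(z)  # last element of the fusion

# after convert, conceptually:
def converted(xq, yq, scale, zero_point):
    # base op -> fused quantized op
    z = torch.ops.quantized.add_relu(xq, yq, scale, zero_point)
    # relu -> identity
    return _identity(z)
```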

Test Plan:
```
python test/test_quantization.py TestQuantizeDBR.test_fusion_functions
```

Reviewed By: jerryzh168

Differential Revision: D33775097

Pulled By: vkuzo

fbshipit-source-id: 12249b85b2f7ba7545a54872aeb5f1ff2fc928cf
Vasiliy Kuznetsov authored on 2022-02-07 05:55:40 -08:00; committed by Facebook GitHub Bot
parent c1f9f38ca1
commit 0db4324ea9
4 changed files with 105 additions and 14 deletions


@@ -432,12 +432,27 @@ class TestQuantizeDBR(QuantizeDBRTestCase):
m = M().eval()
qconfig = torch.quantization.default_qconfig
mp = _quantize_dbr.prepare(m, {'': qconfig}, (torch.randn(1, 1, 1, 1),))
self.assertTrue(
mp._auto_quant_state.idx_to_seen_q_op_infos[0].fusion_info is not None)
self.assertTrue(
mp._auto_quant_state.idx_to_seen_q_op_infos[1].fusion_info is not None)
# TODO(future PR): use fusion results to insert observers
# TODO(future PR): use fusion results to replace function at inference
# verify that the output of the add (inside the fusion) is not observed
self.assertTrue(
'1' not in mp._auto_quant_state.tensor_id_to_observer)
# verify that the output of the relu (end of the fusion) is observed
self.assertTrue(
'2' in mp._auto_quant_state.tensor_id_to_observer)
mp(torch.randn(1, 1, 1, 1))
mq = _quantize_dbr.convert(mp)
# verify that the add-relu got fused
mqt = torch.jit.trace(mq, (torch.randn(1, 1, 1, 1),))
FileCheck().check_count("quantized::add_relu", 1, exactly=True).run(
mqt.graph)
# TODO(future PR): use information about non-quantizeable ops when
# matching fusion patterns
@@ -578,7 +593,6 @@ class TestQuantizeDBR(QuantizeDBRTestCase):
# TODO enable scripting support for this
do_torchscript_checks=False)
@unittest.skip('will be reenabled in the next PR')
def test_method(self):
class M(torch.nn.Module):
def forward(self, x):
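
For context, here is the prepare → calibrate → convert flow the test above exercises, as a standalone sketch; the module `M` is a hypothetical stand-in (its definition sits outside the hunk) and the `_quantize_dbr` import path is an assumption:

```
import torch
from torch.ao.quantization import _quantize_dbr  # assumed import path

class M(torch.nn.Module):  # hypothetical module containing an add -> relu
    def forward(self, x):
        x = x + x
        return torch.relu(x)

m = M().eval()
qconfig = torch.quantization.default_qconfig
example_input = (torch.randn(1, 1, 1, 1),)
mp = _quantize_dbr.prepare(m, {'': qconfig}, example_input)  # insert observers
mp(*example_input)                                           # calibrate
mq = _quantize_dbr.convert(mp)                               # swap in quantized ops
```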


@@ -5,6 +5,7 @@ from .utils import (
FusionInfo,
SeenQOpInfo,
get_users_of_seen_q_op_info,
get_producer_of_seen_q_op_info,
)
def _identity(x):
@@ -30,6 +31,29 @@ def pattern_is_match(
break
return is_match
def get_seen_q_op_info_of_start_of_fusion(
seen_q_op_info_end_of_fusion: SeenQOpInfo,
idx_to_seen_q_op_infos: Dict[int, SeenQOpInfo],
) -> SeenQOpInfo:
assert seen_q_op_info_end_of_fusion.fusion_info is not None
cur_seen_q_op_info = seen_q_op_info_end_of_fusion
for idx in range(len(seen_q_op_info_end_of_fusion.fusion_info.pattern) - 1):
cur_seen_q_op_info = get_producer_of_seen_q_op_info(
idx_to_seen_q_op_infos, cur_seen_q_op_info) # type: ignore[assignment]
return cur_seen_q_op_info
def get_seen_q_op_info_of_end_of_fusion(
seen_q_op_info_start_of_fusion: SeenQOpInfo,
idx_to_seen_q_op_infos: Dict[int, SeenQOpInfo],
) -> SeenQOpInfo:
assert seen_q_op_info_start_of_fusion.fusion_info is not None
cur_seen_q_op_info = seen_q_op_info_start_of_fusion
for idx in range(len(seen_q_op_info_start_of_fusion.fusion_info.pattern) - 1):
users = get_users_of_seen_q_op_info(
idx_to_seen_q_op_infos, cur_seen_q_op_info)
cur_seen_q_op_info = users[0]
return cur_seen_q_op_info
def match_fusion_patterns(
idx_to_seen_q_op_infos: Dict[int, SeenQOpInfo],
):
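
These two helpers walk a fusion one hop per extra pattern element: from the last op back through producers to the base op, and from the base op forward through users to the last op. A minimal illustration with stand-in structures (`FakeOpInfo` is hypothetical; the real `SeenQOpInfo` carries more state and ops can have multiple users):

```
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

@dataclass
class FakeOpInfo:
    idx: int
    pattern: Tuple[str, ...]     # stand-in for fusion_info.pattern
    producer_idx: Optional[int]  # simplified to a single producer
    user_idx: Optional[int]      # simplified to a single user

# an add -> relu fusion, seen as two ops
infos: Dict[int, FakeOpInfo] = {
    0: FakeOpInfo(0, ('add', 'relu'), producer_idx=None, user_idx=1),
    1: FakeOpInfo(1, ('add', 'relu'), producer_idx=0, user_idx=None),
}

def start_of_fusion(end: FakeOpInfo) -> FakeOpInfo:
    cur = end
    for _ in range(len(cur.pattern) - 1):  # one hop per extra element
        cur = infos[cur.producer_idx]
    return cur

def end_of_fusion(start: FakeOpInfo) -> FakeOpInfo:
    cur = start
    for _ in range(len(cur.pattern) - 1):
        cur = infos[cur.user_idx]
    return cur

assert start_of_fusion(infos[1]).idx == 0  # relu -> add
assert end_of_fusion(infos[0]).idx == 1    # add -> relu
```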


@@ -38,6 +38,8 @@ from .utils import (
from .function_fusion import (
match_fusion_patterns,
get_seen_q_op_info_of_start_of_fusion,
get_seen_q_op_info_of_end_of_fusion,
)
OpConvertInfo = Tuple[
@@ -365,13 +367,37 @@ class AutoQuantizationState(torch.nn.Module):
* observe the output, if needed
"""
seen_q_op_info = self._get_cur_seen_q_op_info()
func_output_obs_type = get_func_output_obs_type(seen_q_op_info)
# TODO(future PR): other output types
if func_output_obs_type != FuncOutputObsType.NONE:
seen_q_op_info = self._get_cur_seen_q_op_info()
tensor_id = seen_q_op_info.output_tensor_infos[0].id
obs = self.tensor_id_to_observer[str(tensor_id)]
output = obs(output)
# if we are in a fusion, we only observe at the end of it
is_fusion = seen_q_op_info.fusion_info is not None
is_end_of_fusion = seen_q_op_info.fusion_info is not None and \
seen_q_op_info.fusion_info.is_last_element
if is_fusion:
if is_end_of_fusion:
# observe at the end of the fusion, according to the info
# of the base op
seen_q_op_info_start = get_seen_q_op_info_of_start_of_fusion(
seen_q_op_info, self.idx_to_seen_q_op_infos)
# use the obs type from beginning of pattern
func_output_obs_type = get_func_output_obs_type(seen_q_op_info_start)
if func_output_obs_type != FuncOutputObsType.NONE:
# use the output tensor ID from the end of pattern
tensor_id = seen_q_op_info.output_tensor_infos[0].id
obs = self.tensor_id_to_observer[str(tensor_id)]
output = obs(output)
else:
# do not observe in the middle of fusions
pass
else:
# not in a fusion, observe as normal
func_output_obs_type = get_func_output_obs_type(seen_q_op_info)
# TODO(future PR): other output types
if func_output_obs_type != FuncOutputObsType.NONE:
tensor_id = seen_q_op_info.output_tensor_infos[0].id
obs = self.tensor_id_to_observer[str(tensor_id)]
output = obs(output)
if self.log_op_outputs:
output_clone = clone_detach_tensor_without_dispatch(output)
@@ -523,7 +549,8 @@ class AutoQuantizationState(torch.nn.Module):
`get_op_convert_info`.
"""
# calculate new op
maybe_new_op = get_quantized_op(seen_q_op_info)
maybe_new_op = get_quantized_op(
seen_q_op_info, self.idx_to_seen_q_op_infos)
# calculate quant infos
arg_quant_infos, arg_dequant_infos, any_arg_quant_or_dequant_needed = \
@@ -539,7 +566,17 @@ class AutoQuantizationState(torch.nn.Module):
additional_kwargs = {}
needs_scale_zp = converted_func_needs_scale_zp(seen_q_op_info)
if needs_scale_zp:
output_tensor_infos = seen_q_op_info.output_tensor_infos
cur_seen_q_op_info = seen_q_op_info
# if this is the start of a fusion pattern, get the observer
# from the end of the fusion
is_start_of_fusion = seen_q_op_info.fusion_info and \
seen_q_op_info.fusion_info.is_first_element
if is_start_of_fusion:
cur_seen_q_op_info = get_seen_q_op_info_of_end_of_fusion(
seen_q_op_info, self.idx_to_seen_q_op_infos)
output_tensor_infos = cur_seen_q_op_info.output_tensor_infos
tensor_id = output_tensor_infos[0].id
scale, zp = self.tensor_id_to_scale_zp[tensor_id]
additional_kwargs.update({'scale': scale, 'zero_point': zp})
@@ -876,9 +913,20 @@ class AutoQuantizationState(torch.nn.Module):
seen_q_op_info: SeenQOpInfo,
root_module: torch.nn.Module,
):
func_output_obs_type = get_func_output_obs_type(seen_q_op_info)
output_tensor_id = seen_q_op_info.output_tensor_infos[0].id
if seen_q_op_info.fusion_info is not None:
if not seen_q_op_info.fusion_info.is_first_element:
# if we are in a fusion but not at the start, do not insert observer
return
else:
# if we are in a fusion and at the start, insert observer for its end
# get the output of the end of the fusion
cur_seen_q_op_info = get_seen_q_op_info_of_end_of_fusion(
seen_q_op_info, self.idx_to_seen_q_op_infos)
output_tensor_id = cur_seen_q_op_info.output_tensor_infos[0].id
else:
output_tensor_id = seen_q_op_info.output_tensor_infos[0].id
func_output_obs_type = get_func_output_obs_type(seen_q_op_info)
if func_output_obs_type == FuncOutputObsType.NEW_OBS:
# TODO(future PR): check qconfig is None
qconfig = get_cur_qconfig(
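
All three hunks in this file apply the same rule: observation state (the observer and, after calibration, its scale/zero_point) is keyed off the output of the last op in the fusion, while replacement decisions are made at the first op. A hedged sketch of the convert-time scale/zero_point lookup, with stand-in state (`FakeInfo` and the numbers are hypothetical; tensor id 2 mirrors the relu output in the test above):

```
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeInfo:
    output_tensor_id: int
    is_first_element_of_fusion: bool
    end_of_fusion: Optional['FakeInfo'] = None  # precomputed for the sketch

# observer statistics gathered during calibration, keyed by tensor id
tensor_id_to_scale_zp = {2: (0.05, 0)}  # id 2: relu output (end of fusion)

relu_info = FakeInfo(output_tensor_id=2, is_first_element_of_fusion=False)
add_info = FakeInfo(output_tensor_id=1, is_first_element_of_fusion=True,
                    end_of_fusion=relu_info)

def scale_zp_kwargs(info: FakeInfo) -> dict:
    # the fused op is emitted at the base op, but its output statistics
    # were gathered at the end of the fusion, so hop there before looking up
    cur = info.end_of_fusion if info.is_first_element_of_fusion else info
    scale, zp = tensor_id_to_scale_zp[cur.output_tensor_id]
    return {'scale': scale, 'zero_point': zp}

assert scale_zp_kwargs(add_info) == {'scale': 0.05, 'zero_point': 0}
```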


@@ -369,12 +369,17 @@ def get_op_packing_only_uses_module_attributes(
def get_quantized_op(
seen_q_op_info: SeenQOpInfo,
idx_to_seen_q_op_infos: Dict[int, SeenQOpInfo],
) -> Optional[Callable]:
"""
Given a `seen_q_op_info`, returns the quantized version of the seen function.
If the `seen_q_op_info` corresponds to a module, returns `None`.
If the function does not need quantizing, returns `None`.
"""
# if we are in a fusion, use the fusion replacement rules
if seen_q_op_info.fusion_info is not None:
return seen_q_op_info.fusion_info.replacement_type_this_element
op_type = seen_q_op_info.type
is_module = isinstance(op_type, type(torch.nn.Module))
if is_module:
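
Taken together, the `fusion_info` fields this diff reads (`pattern`, `is_first_element`, `is_last_element`, `replacement_type_this_element`) imply a record roughly like the following; this is a reconstruction from usage, not the actual definition:

```
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple

@dataclass
class FusionInfoSketch:
    # the matched pattern, e.g. (torch.add, torch.relu)
    pattern: Tuple[Any, ...]
    # True for the base op (first element of the fusion)
    is_first_element: bool
    # True for the op whose output gets observed
    is_last_element: bool
    # the fused op for the base op (e.g. torch.ops.quantized.add_relu),
    # an identity function for every other element
    replacement_type_this_element: Optional[Callable]
```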