Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
dbr quant function fusion [2/x]: use fusion for observation and inference (#71781)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71781

The previous PR added information about fusions found in the subgraphs. This PR uses that information to:
1. insert observers at the end of fusions instead of in the middle
2. replace the original op with the fused op during inference

This is implemented by replacing the base op with the fused op and replacing all other ops in the pattern with identity functions.

Test Plan:
```
python test/test_quantization.py TestQuantizeDBR.test_fusion_functions
```

Reviewed By: jerryzh168

Differential Revision: D33775097

Pulled By: vkuzo

fbshipit-source-id: 12249b85b2f7ba7545a54872aeb5f1ff2fc928cf
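To make the replacement concrete, here is a hand-written sketch of the idea described in the summary (not code from this PR; the dict layout is an assumption based on the description, though `_identity` itself does appear in the diff below):

```
import torch

def _identity(x):
    return x

# for a matched (torch.add, torch.relu) pattern at inference time, the base op
# is swapped for the fused quantized kernel and the trailing op becomes a no-op
replacements = {
    torch.add: torch.ops.quantized.add_relu,  # base op -> fused op
    torch.relu: _identity,                    # all other ops -> identity
}
```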
This commit is contained in:
parent c1f9f38ca1
commit 0db4324ea9
```
@@ -432,12 +432,27 @@ class TestQuantizeDBR(QuantizeDBRTestCase):
        m = M().eval()
        qconfig = torch.quantization.default_qconfig
        mp = _quantize_dbr.prepare(m, {'': qconfig}, (torch.randn(1, 1, 1, 1),))

        self.assertTrue(
            mp._auto_quant_state.idx_to_seen_q_op_infos[0].fusion_info is not None)
        self.assertTrue(
            mp._auto_quant_state.idx_to_seen_q_op_infos[1].fusion_info is not None)
        # TODO(future PR): use fusion results to insert observers
        # TODO(future PR): use fusion results to replace function at inference

        # verify that the output of the add is not observed
        self.assertTrue(
            '1' not in mp._auto_quant_state.tensor_id_to_observer)
        # verify that the output of the relu is observed
        self.assertTrue(
            '2' in mp._auto_quant_state.tensor_id_to_observer)

        mp(torch.randn(1, 1, 1, 1))
        mq = _quantize_dbr.convert(mp)

        # verify that the add-relu got fused
        mqt = torch.jit.trace(mq, (torch.randn(1, 1, 1, 1),))
        FileCheck().check_count("quantized::add_relu", 1, exactly=True).run(
            mqt.graph)

        # TODO(future PR): use information about non-quantizeable ops during
        # matching fusion patterns
```
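For readers unfamiliar with the pattern used in this test, `FileCheck` (from `torch.testing`) asserts on the textual form of a TorchScript graph. A minimal standalone example of the same check style (the module and pattern here are illustrative, not from this PR):

```
import torch
from torch.testing import FileCheck

# trace a trivial module and assert its graph contains exactly one relu node
traced = torch.jit.trace(torch.nn.ReLU(), (torch.randn(2),))
FileCheck().check_count("aten::relu", 1, exactly=True).run(traced.graph)
```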
```
@@ -578,7 +593,6 @@ class TestQuantizeDBR(QuantizeDBRTestCase):
            # TODO enable scripting support for this
            do_torchscript_checks=False)

    @unittest.skip('will be reenabled in the next PR')
    def test_method(self):
        class M(torch.nn.Module):
            def forward(self, x):
```
```
@@ -5,6 +5,7 @@ from .utils import (
    FusionInfo,
    SeenQOpInfo,
    get_users_of_seen_q_op_info,
    get_producer_of_seen_q_op_info,
)

def _identity(x):
```
```
@@ -30,6 +31,29 @@ def pattern_is_match(
            break
    return is_match

def get_seen_q_op_info_of_start_of_fusion(
    seen_q_op_info_end_of_fusion: SeenQOpInfo,
    idx_to_seen_q_op_infos: Dict[int, SeenQOpInfo],
) -> SeenQOpInfo:
    assert seen_q_op_info_end_of_fusion.fusion_info is not None
    cur_seen_q_op_info = seen_q_op_info_end_of_fusion
    for idx in range(len(seen_q_op_info_end_of_fusion.fusion_info.pattern) - 1):
        cur_seen_q_op_info = get_producer_of_seen_q_op_info(
            idx_to_seen_q_op_infos, cur_seen_q_op_info)  # type: ignore[assignment]
    return cur_seen_q_op_info

def get_seen_q_op_info_of_end_of_fusion(
    seen_q_op_info_start_of_fusion: SeenQOpInfo,
    idx_to_seen_q_op_infos: Dict[int, SeenQOpInfo],
) -> SeenQOpInfo:
    assert seen_q_op_info_start_of_fusion.fusion_info is not None
    cur_seen_q_op_info = seen_q_op_info_start_of_fusion
    for idx in range(len(seen_q_op_info_start_of_fusion.fusion_info.pattern) - 1):
        users = get_users_of_seen_q_op_info(
            idx_to_seen_q_op_infos, cur_seen_q_op_info)
        cur_seen_q_op_info = users[0]
    return cur_seen_q_op_info

def match_fusion_patterns(
    idx_to_seen_q_op_infos: Dict[int, SeenQOpInfo],
):
```
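A self-contained toy version of the end-of-fusion walk above, to make the traversal concrete (all names here are hypothetical stand-ins, not the PyTorch API):

```
from dataclasses import dataclass
from typing import Dict, Tuple

@dataclass
class ToyOpInfo:
    idx: int                  # position in program order
    pattern: Tuple[str, ...]  # the matched fusion pattern, e.g. ('add', 'relu')

def toy_end_of_fusion(start: ToyOpInfo, idx_to_info: Dict[int, ToyOpInfo]) -> ToyOpInfo:
    # walk len(pattern) - 1 "user" hops from the start, mirroring
    # get_seen_q_op_info_of_end_of_fusion
    cur = start
    for _ in range(len(start.pattern) - 1):
        cur = idx_to_info[cur.idx + 1]  # toy users lookup: the next op in order
    return cur

infos = {0: ToyOpInfo(0, ('add', 'relu')), 1: ToyOpInfo(1, ('add', 'relu'))}
assert toy_end_of_fusion(infos[0], infos).idx == 1  # the relu ends the fusion
```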
```
@@ -38,6 +38,8 @@ from .utils import (
from .function_fusion import (
    match_fusion_patterns,
    get_seen_q_op_info_of_start_of_fusion,
    get_seen_q_op_info_of_end_of_fusion,
)

OpConvertInfo = Tuple[
```
```
@@ -365,10 +367,34 @@ class AutoQuantizationState(torch.nn.Module):
        * observe the output, if needed
        """
        seen_q_op_info = self._get_cur_seen_q_op_info()

        # if we are in a fusion, we only observe at the end of it
        is_fusion = seen_q_op_info.fusion_info is not None
        is_end_of_fusion = seen_q_op_info.fusion_info is not None and \
            seen_q_op_info.fusion_info.is_last_element

        if is_fusion:
            if is_end_of_fusion:
                # observe at the end of the fusion, according to the info
                # of the base op
                seen_q_op_info_start = get_seen_q_op_info_of_start_of_fusion(
                    seen_q_op_info, self.idx_to_seen_q_op_infos)
                # use the obs type from the beginning of the pattern
                func_output_obs_type = get_func_output_obs_type(seen_q_op_info_start)
                if func_output_obs_type != FuncOutputObsType.NONE:
                    # use the output tensor ID from the end of the pattern
                    tensor_id = seen_q_op_info.output_tensor_infos[0].id
                    obs = self.tensor_id_to_observer[str(tensor_id)]
                    output = obs(output)
            else:
                # do not observe in the middle of a fusion
                pass
        else:
            # no fusion, observe as usual
            func_output_obs_type = get_func_output_obs_type(seen_q_op_info)
            # TODO(future PR): other output types
            if func_output_obs_type != FuncOutputObsType.NONE:
                seen_q_op_info = self._get_cur_seen_q_op_info()
                tensor_id = seen_q_op_info.output_tensor_infos[0].id
                obs = self.tensor_id_to_observer[str(tensor_id)]
                output = obs(output)
```
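The branching above reduces to a small decision rule; a hypothetical standalone restatement (not code from this PR):

```
# only non-fused ops and the last element of a fusion observe their output
def should_observe_output(in_fusion: bool, is_last_element: bool) -> bool:
    if in_fusion:
        return is_last_element  # observe only at the end of the fusion
    return True                 # ops outside fusions observe as before

assert should_observe_output(False, False)      # standalone op: observe
assert not should_observe_output(True, False)   # middle of add+relu: skip
assert should_observe_output(True, True)        # end of add+relu: observe
```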
```
@@ -523,7 +549,8 @@ class AutoQuantizationState(torch.nn.Module):
        `get_op_convert_info`.
        """
        # calculate new op
        maybe_new_op = get_quantized_op(seen_q_op_info)
        maybe_new_op = get_quantized_op(
            seen_q_op_info, self.idx_to_seen_q_op_infos)

        # calculate quant infos
        arg_quant_infos, arg_dequant_infos, any_arg_quant_or_dequant_needed = \
```
```
@@ -539,7 +566,17 @@ class AutoQuantizationState(torch.nn.Module):
        additional_kwargs = {}
        needs_scale_zp = converted_func_needs_scale_zp(seen_q_op_info)
        if needs_scale_zp:
            output_tensor_infos = seen_q_op_info.output_tensor_infos
            cur_seen_q_op_info = seen_q_op_info

            # if this is a start of a fusion pattern, get the observer
            # from the end of the fusion
            is_start_of_fusion = seen_q_op_info.fusion_info and \
                seen_q_op_info.fusion_info.is_first_element
            if is_start_of_fusion:
                cur_seen_q_op_info = get_seen_q_op_info_of_end_of_fusion(
                    seen_q_op_info, self.idx_to_seen_q_op_infos)

            output_tensor_infos = cur_seen_q_op_info.output_tensor_infos
            tensor_id = output_tensor_infos[0].id
            scale, zp = self.tensor_id_to_scale_zp[tensor_id]
            additional_kwargs.update({'scale': scale, 'zero_point': zp})
```
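For context on why `scale` and `zero_point` are passed as extra kwargs: fused quantized kernels such as `torch.ops.quantized.add_relu` take the output requantization parameters explicitly, and here they are read from the observer at the end of the fusion. A standalone example (values are illustrative):

```
import torch

qa = torch.quantize_per_tensor(torch.randn(4), scale=0.05, zero_point=0, dtype=torch.quint8)
qb = torch.quantize_per_tensor(torch.randn(4), scale=0.05, zero_point=0, dtype=torch.quint8)
# the fused kernel consumes the output scale and zero point directly
out = torch.ops.quantized.add_relu(qa, qb, 0.05, 0)
```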
```
@@ -876,9 +913,20 @@ class AutoQuantizationState(torch.nn.Module):
        seen_q_op_info: SeenQOpInfo,
        root_module: torch.nn.Module,
    ):
        func_output_obs_type = get_func_output_obs_type(seen_q_op_info)
        if seen_q_op_info.fusion_info is not None:
            if not seen_q_op_info.fusion_info.is_first_element:
                # if we are in a fusion but not at the start, do not insert observer
                return
            else:
                # if we are in a fusion and at the start, insert observer for its end
                # get the output of the end of the fusion
                cur_seen_q_op_info = get_seen_q_op_info_of_end_of_fusion(
                    seen_q_op_info, self.idx_to_seen_q_op_infos)
                output_tensor_id = cur_seen_q_op_info.output_tensor_infos[0].id
        else:
            output_tensor_id = seen_q_op_info.output_tensor_infos[0].id

        func_output_obs_type = get_func_output_obs_type(seen_q_op_info)
        if func_output_obs_type == FuncOutputObsType.NEW_OBS:
            # TODO(future PR): check qconfig is None
            qconfig = get_cur_qconfig(
```
```
@@ -369,12 +369,17 @@ def get_op_packing_only_uses_module_attributes(

def get_quantized_op(
    seen_q_op_info: SeenQOpInfo,
    idx_to_seen_q_op_infos: Dict[int, SeenQOpInfo],
) -> Optional[Callable]:
    """
    Given a `seen_q_op_info`, returns the quantized version of the seen function.
    If the `seen_q_op_info` corresponds to a module, returns `None`.
    If the function does not need quantizing, returns `None`.
    """
    # if we are in a fusion, use the fusion replacement rules
    if seen_q_op_info.fusion_info is not None:
        return seen_q_op_info.fusion_info.replacement_type_this_element

    op_type = seen_q_op_info.type
    is_module = isinstance(op_type, type(torch.nn.Module))
    if is_module:
```
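A short aside on the `is_module` check above: module "types" are classes, i.e. instances of the metaclass `type(torch.nn.Module)`, while functional ops are plain callables, so `isinstance` distinguishes the two:

```
import torch

assert isinstance(torch.nn.Conv2d, type(torch.nn.Module))  # a module class
assert not isinstance(torch.add, type(torch.nn.Module))    # a plain function
```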