[quant][graphmode][fx] Add support for keeping output quantized for list and dict (#56391)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56391

Previously we only supported keeping the output quantized for tensor outputs; this PR adds
support for list and dict (values) outputs as well.
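
For example, a minimal usage sketch (not part of this PR's diff; it assumes the
torch.quantization.quantize_fx API of this era and uses post-training static
quantization for brevity):

    import torch
    from torch.quantization import get_default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    class DictOutputModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 1, 1)

        def forward(self, x):
            # dict output: with this PR, its values can be kept quantized
            return {"output": self.conv(x)}

    model = DictOutputModel().eval()
    qconfig_dict = {'': get_default_qconfig('fbgemm')}
    # output 0 of the model is hardcoded to stay quantized
    prepare_custom_config_dict = {'output_quantized_idxs': [0]}

    prepared = prepare_fx(model, qconfig_dict,
                          prepare_custom_config_dict=prepare_custom_config_dict)
    prepared(torch.randn(4, 1, 4, 4))  # calibration pass
    quantized = convert_fx(prepared)

    out = quantized(torch.randn(4, 1, 4, 4))
    assert out["output"].is_quantized  # no trailing dequantize inserted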

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D27860327

fbshipit-source-id: e770160ced47a7173abff5505ec620bd2b1a0b01
Authored by Jerry Zhang on 2021-04-19 21:36:01 -07:00; committed by Facebook GitHub Bot
parent 42f0fe1fe3
commit 94406f77f6
2 changed files with 126 additions and 67 deletions


@@ -2155,6 +2155,83 @@ class TestQuantizeFx(QuantizationTestCase):
         self.assertEqual(dequant, 1)
         self.assertEqual(quant, 1)

+    def test_quant_output_always_observed(self):
+        """
+        If the output is hardcoded to be quantized, ensure that
+        there is always an observer, even if the last non-output node is not
+        quantizeable.
+        """
+        qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
+        prepare_custom_config_dict = {'output_quantized_idxs': [0]}
+        data = (torch.randn(4, 1, 4, 4),)
+
+        # non-quantizeable node, quantized output
+        class M1(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.identity = torch.nn.Identity()
+
+            def forward(self, x):
+                x = self.identity(x)
+                return x
+
+        m1 = M1()
+        self.checkGraphModeFxOp(
+            m1, data, QuantType.QAT,
+            prepare_expected_node_occurrence={
+                ns.call_module(torch.quantization.FakeQuantize): 1,
+            },
+            expected_node_occurrence={
+                ns.call_function(torch.quantize_per_tensor): 1,
+            },
+            prepare_custom_config_dict=prepare_custom_config_dict,
+            print_debug_info=True)
+
+        # quantizeable node, quantized output
+        class M2(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                x = self.conv(x)
+                return x
+
+        m2 = M2()
+        self.checkGraphModeFxOp(
+            m2, data, QuantType.QAT,
+            prepare_expected_node_occurrence={
+                # one for weights, one for activations
+                ns.call_module(torch.quantization.FakeQuantize): 2,
+            },
+            expected_node_occurrence={
+                ns.call_function(torch.quantize_per_tensor): 1,
+            },
+            prepare_custom_config_dict=prepare_custom_config_dict,
+            print_debug_info=True)
+
+        # quantizeable node, quantized dictionary output
+        class M3(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                x = self.conv(x)
+                return {"output": x}
+
+        m3 = M3()
+        self.checkGraphModeFxOp(
+            m3, data, QuantType.QAT,
+            prepare_expected_node_occurrence={
+                # one for weights, one for activations
+                ns.call_module(torch.quantization.FakeQuantize): 2,
+            },
+            expected_node_occurrence={
+                ns.call_function(torch.quantize_per_tensor): 1,
+            },
+            prepare_custom_config_dict=prepare_custom_config_dict)
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
@@ -2795,59 +2872,6 @@ class TestQuantizeFxOps(QuantizationTestCase):
         self.checkGraphModuleNodes(
             mp, expected_node_occurrence=expected_node_occurrence)

-    def test_quant_output_always_observed(self):
-        """
-        If the output is hardcoded to be quantized, ensure that
-        there is always an observer, even if the last non-output node is not
-        quantizeable.
-        """
-        qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
-        prepare_custom_config_dict = {'output_quantized_idxs': [0]}
-        data = (torch.randn(4, 1, 4, 4),)
-
-        # non-quantizeable node, quantized output
-        class M1(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.identity = torch.nn.Identity()
-
-            def forward(self, x):
-                x = self.identity(x)
-                return x
-
-        m1 = M1()
-        self.checkGraphModeFxOp(
-            m1, data, QuantType.QAT,
-            prepare_expected_node_occurrence={
-                ns.call_module(torch.quantization.FakeQuantize): 1,
-            },
-            expected_node_occurrence={
-                ns.call_function(torch.quantize_per_tensor): 1,
-            },
-            prepare_custom_config_dict=prepare_custom_config_dict)
-
-        # quantizeable node, quantized output
-        class M2(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.conv = torch.nn.Conv2d(1, 1, 1)
-
-            def forward(self, x):
-                x = self.conv(x)
-                return x
-
-        m2 = M2()
-        self.checkGraphModeFxOp(
-            m2, data, QuantType.QAT,
-            prepare_expected_node_occurrence={
-                # one for weights, one for activations
-                ns.call_module(torch.quantization.FakeQuantize): 2,
-            },
-            expected_node_occurrence={
-                ns.call_function(torch.quantize_per_tensor): 1,
-            },
-            prepare_custom_config_dict=prepare_custom_config_dict)
-
     @skipIfNoFBGEMM
     def test_cat(self):
         """ quantization of the output of cat will depend on the


@@ -344,6 +344,48 @@ def insert_observer_for_input_arg_of_observed_node(
             activation_post_process_indexes,
             env, observed_graph, load_arg, observed_node_names_set, quants)

+def insert_observer_for_output_of_model(
+        node: Node,
+        model: torch.nn.Module,
+        qconfig_map: Dict[str, QConfigAny],
+        activation_post_process_map: Dict[str, List[str]],
+        activation_post_process_indexes: Dict[str, int],
+        env: Dict[Any, Any], observed_graph: Graph, load_arg: Callable,
+        observed_node_names_set: Set[str],
+        quants: Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]]):
+    if isinstance(node, Node):
+        assert qconfig_map is not None
+        local_qconfig = qconfig_map[node.name]
+        assert local_qconfig is not None, \
+            'qconfig of a node before a quantized output must exist'
+        if node.name not in observed_node_names_set:
+            insert_observer(
+                node, local_qconfig.activation(),
+                model,
+                activation_post_process_map,
+                activation_post_process_indexes,
+                env, observed_graph, load_arg, observed_node_names_set, quants)
+    elif isinstance(node, (list, tuple)):
+        for n in node:
+            insert_observer_for_output_of_model(
+                n,
+                model,
+                qconfig_map,
+                activation_post_process_map,
+                activation_post_process_indexes,
+                env, observed_graph, load_arg, observed_node_names_set, quants)
+    elif isinstance(node, dict):
+        for n in node.values():
+            insert_observer_for_output_of_model(
+                n,
+                model,
+                qconfig_map,
+                activation_post_process_map,
+                activation_post_process_indexes,
+                env, observed_graph, load_arg, observed_node_names_set, quants)
+    else:
+        raise Exception(
+            "hardcoding output to be quantized not supported: " + str(type(node)))
+
 def insert_observers_for_model(
         model: GraphModule,
         modules: Dict[str, torch.nn.Module],
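
The dispatch in the new helper is easiest to see in isolation. A standalone toy
sketch (collect_output_nodes is an invented name for this illustration; the
real helper inserts observers instead of collecting names): a single Node is
handled directly, lists/tuples and dict values recurse, and any other output
type raises.

    from typing import Any, List

    def collect_output_nodes(output: Any) -> List[str]:
        # a str stands in for a single fx Node in this toy version
        if isinstance(output, str):
            return [output]
        elif isinstance(output, (list, tuple)):
            return [n for item in output for n in collect_output_nodes(item)]
        elif isinstance(output, dict):
            # only the values are observed; keys are left alone
            return [n for v in output.values() for n in collect_output_nodes(v)]
        else:
            raise Exception(
                "hardcoding output to be quantized not supported: " + str(type(output)))

    # nested containers work because each branch recurses:
    assert collect_output_nodes({"a": "x", "b": ["y", "z"]}) == ["x", "y", "z"]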
@@ -377,20 +419,13 @@ def insert_observers_for_model(
             output_node_seen_cnt += 1
             if cur_output_node_idx in output_quantized_idxs:
                 prev_node = node.args[0]
-                assert isinstance(prev_node, Node), \
-                    ('hardcoding list/dict outputs to be quantized is ' +
-                     'not supported')
-                if prev_node.name not in observed_node_names_set:
-                    assert qconfig_map is not None
-                    local_qconfig = qconfig_map[prev_node.name]
-                    assert local_qconfig is not None, \
-                        'qconfig of a node before a quantized output must exist'
-                    insert_observer(
-                        prev_node, local_qconfig.activation(),
-                        model,
-                        activation_post_process_map,
-                        activation_post_process_indexes,
-                        env, observed_graph, load_arg, observed_node_names_set, quants)
+                insert_observer_for_output_of_model(
+                    prev_node,
+                    model,
+                    qconfig_map,
+                    activation_post_process_map,
+                    activation_post_process_indexes,
+                    env, observed_graph, load_arg, observed_node_names_set, quants)
             observed_graph.output(load_arg(node.args[0]))
             result_node = node
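
On the convert side, the marked output values simply stay quantized; a caller
that needs float results dequantizes them explicitly. Continuing the sketch
from the summary above:

    out = quantized(torch.randn(4, 1, 4, 4))
    # each dict value is a quantized tensor; dequantize manually where needed
    float_out = {k: v.dequantize() for k, v in out.items()}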