mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[quant][graphmode][fx] Add support for keeping output quantized for list and dict (#56391)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56391 Previously we only support keeping output quantized for tensor output, this PR adds support for list and dict (values) as well Test Plan: Imported from OSS Reviewed By: vkuzo Differential Revision: D27860327 fbshipit-source-id: e770160ced47a7173abff5505ec620bd2b1a0b01
This commit is contained in:
parent
42f0fe1fe3
commit
94406f77f6
|
|
@ -2155,6 +2155,83 @@ class TestQuantizeFx(QuantizationTestCase):
|
|||
self.assertEqual(dequant, 1)
|
||||
self.assertEqual(quant, 1)
|
||||
|
||||
def test_quant_output_always_observed(self):
|
||||
"""
|
||||
If the output is hardcoded to be quantized, ensure that
|
||||
there is always an observer, even if the last non-output node is not
|
||||
quantizeable.
|
||||
"""
|
||||
qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
|
||||
prepare_custom_config_dict = {'output_quantized_idxs': [0]}
|
||||
data = (torch.randn(4, 1, 4, 4),)
|
||||
|
||||
# non-quantizeable node, quantized output
|
||||
class M1(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.identity = torch.nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.identity(x)
|
||||
return x
|
||||
|
||||
m1 = M1()
|
||||
self.checkGraphModeFxOp(
|
||||
m1, data, QuantType.QAT,
|
||||
prepare_expected_node_occurrence={
|
||||
ns.call_module(torch.quantization.FakeQuantize): 1,
|
||||
},
|
||||
expected_node_occurrence={
|
||||
ns.call_function(torch.quantize_per_tensor): 1,
|
||||
},
|
||||
prepare_custom_config_dict=prepare_custom_config_dict,
|
||||
print_debug_info=True)
|
||||
|
||||
# quantizeable node, quantized output
|
||||
class M2(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(1, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
m2 = M2()
|
||||
self.checkGraphModeFxOp(
|
||||
m2, data, QuantType.QAT,
|
||||
prepare_expected_node_occurrence={
|
||||
# one for weights, one for activations
|
||||
ns.call_module(torch.quantization.FakeQuantize): 2,
|
||||
},
|
||||
expected_node_occurrence={
|
||||
ns.call_function(torch.quantize_per_tensor): 1,
|
||||
},
|
||||
prepare_custom_config_dict=prepare_custom_config_dict,
|
||||
print_debug_info=True)
|
||||
|
||||
# quantizeable node, quantized dictionary output
|
||||
class M3(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(1, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return {"output": x}
|
||||
|
||||
m3 = M3()
|
||||
self.checkGraphModeFxOp(
|
||||
m3, data, QuantType.QAT,
|
||||
prepare_expected_node_occurrence={
|
||||
# one for weights, one for activations
|
||||
ns.call_module(torch.quantization.FakeQuantize): 2,
|
||||
},
|
||||
expected_node_occurrence={
|
||||
ns.call_function(torch.quantize_per_tensor): 1,
|
||||
},
|
||||
prepare_custom_config_dict=prepare_custom_config_dict)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
class TestQuantizeFxOps(QuantizationTestCase):
|
||||
"""Unit tests for individual ops
|
||||
|
|
@ -2795,59 +2872,6 @@ class TestQuantizeFxOps(QuantizationTestCase):
|
|||
self.checkGraphModuleNodes(
|
||||
mp, expected_node_occurrence=expected_node_occurrence)
|
||||
|
||||
def test_quant_output_always_observed(self):
|
||||
"""
|
||||
If the output is hardcoded to be quantized, ensure that
|
||||
there is always an observer, even if the last non-output node is not
|
||||
quantizeable.
|
||||
"""
|
||||
qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
|
||||
prepare_custom_config_dict = {'output_quantized_idxs': [0]}
|
||||
data = (torch.randn(4, 1, 4, 4),)
|
||||
|
||||
# non-quantizeable node, quantized output
|
||||
class M1(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.identity = torch.nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.identity(x)
|
||||
return x
|
||||
|
||||
m1 = M1()
|
||||
self.checkGraphModeFxOp(
|
||||
m1, data, QuantType.QAT,
|
||||
prepare_expected_node_occurrence={
|
||||
ns.call_module(torch.quantization.FakeQuantize): 1,
|
||||
},
|
||||
expected_node_occurrence={
|
||||
ns.call_function(torch.quantize_per_tensor): 1,
|
||||
},
|
||||
prepare_custom_config_dict=prepare_custom_config_dict)
|
||||
|
||||
# quantizeable node, quantized output
|
||||
class M2(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(1, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
m2 = M2()
|
||||
self.checkGraphModeFxOp(
|
||||
m2, data, QuantType.QAT,
|
||||
prepare_expected_node_occurrence={
|
||||
# one for weights, one for activations
|
||||
ns.call_module(torch.quantization.FakeQuantize): 2,
|
||||
},
|
||||
expected_node_occurrence={
|
||||
ns.call_function(torch.quantize_per_tensor): 1,
|
||||
},
|
||||
prepare_custom_config_dict=prepare_custom_config_dict)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_cat(self):
|
||||
""" quantization of the output of cat will depend on the
|
||||
|
|
|
|||
|
|
@ -344,6 +344,48 @@ def insert_observer_for_input_arg_of_observed_node(
|
|||
activation_post_process_indexes,
|
||||
env, observed_graph, load_arg, observed_node_names_set, quants)
|
||||
|
||||
def insert_observer_for_output_of_model(
|
||||
node: Node,
|
||||
model: torch.nn.Module,
|
||||
qconfig_map: Dict[str, QConfigAny],
|
||||
activation_post_process_map: Dict[str, List[str]],
|
||||
activation_post_process_indexes: Dict[str, int],
|
||||
env: Dict[Any, Any], observed_graph: Graph, load_arg: Callable,
|
||||
observed_node_names_set: Set[str],
|
||||
quants: Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]]):
|
||||
if isinstance(node, Node):
|
||||
assert qconfig_map is not None
|
||||
local_qconfig = qconfig_map[node.name]
|
||||
assert local_qconfig is not None, \
|
||||
'qconfig of a node before a quantized output must exist'
|
||||
if node.name not in observed_node_names_set:
|
||||
insert_observer(
|
||||
node, local_qconfig.activation(),
|
||||
model,
|
||||
activation_post_process_map,
|
||||
activation_post_process_indexes,
|
||||
env, observed_graph, load_arg, observed_node_names_set, quants)
|
||||
elif isinstance(node, list) or isinstance(node, tuple):
|
||||
for n in node:
|
||||
insert_observer_for_output_of_model(
|
||||
n,
|
||||
model,
|
||||
qconfig_map,
|
||||
activation_post_process_map,
|
||||
activation_post_process_indexes,
|
||||
env, observed_graph, load_arg, observed_node_names_set, quants)
|
||||
elif isinstance(node, dict):
|
||||
for n in node.values():
|
||||
insert_observer_for_output_of_model(
|
||||
n,
|
||||
model,
|
||||
qconfig_map,
|
||||
activation_post_process_map,
|
||||
activation_post_process_indexes,
|
||||
env, observed_graph, load_arg, observed_node_names_set, quants)
|
||||
else:
|
||||
raise Exception("hardcoding output to be quantized not supported: " + str(type(node)))
|
||||
|
||||
def insert_observers_for_model(
|
||||
model: GraphModule,
|
||||
modules: Dict[str, torch.nn.Module],
|
||||
|
|
@ -377,20 +419,13 @@ def insert_observers_for_model(
|
|||
output_node_seen_cnt += 1
|
||||
if cur_output_node_idx in output_quantized_idxs:
|
||||
prev_node = node.args[0]
|
||||
assert isinstance(prev_node, Node), \
|
||||
('hardcoding list/dict outputs to be quantized is ' +
|
||||
'not supported')
|
||||
if prev_node.name not in observed_node_names_set:
|
||||
assert qconfig_map is not None
|
||||
local_qconfig = qconfig_map[prev_node.name]
|
||||
assert local_qconfig is not None, \
|
||||
'qconfig of a node before a quantized output must exist'
|
||||
insert_observer(
|
||||
prev_node, local_qconfig.activation(),
|
||||
model,
|
||||
activation_post_process_map,
|
||||
activation_post_process_indexes,
|
||||
env, observed_graph, load_arg, observed_node_names_set, quants)
|
||||
insert_observer_for_output_of_model(
|
||||
prev_node,
|
||||
model,
|
||||
qconfig_map,
|
||||
activation_post_process_map,
|
||||
activation_post_process_indexes,
|
||||
env, observed_graph, load_arg, observed_node_names_set, quants)
|
||||
|
||||
observed_graph.output(load_arg(node.args[0]))
|
||||
result_node = node
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user