[quant][pt2] Handle literal conv args in convert QAT (#103731)
Summary: Similar to the prepare case, we need to manually copy over literal conv args such as padding and stride to the new, replaced conv nodes, since these args are not captured by the subgraph rewriter.

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_qat_conv_bn_fusion_literal_args

Reviewed By: jerryzh168

Differential Revision: D46383130

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103731
Approved by: https://github.com/jerryzh168
parent 08a054649c
commit 2bc56bec07
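For context, the "literal args" in question are everything after the first three arguments of `torch.ops.aten.convolution.default`. A minimal standalone sketch (mine, not part of the commit) of the op's argument layout, which the test-side changes below check positionally:

```python
import torch

# Schema of aten.convolution.default:
#   convolution(Tensor input, Tensor weight, Tensor? bias,
#               int[] stride, int[] padding, int[] dilation,
#               bool transposed, int[] output_padding, int groups)
# Only the first three args are tensors; the remaining six are plain
# Python values ("literals") that the subgraph rewriter cannot bind.
x = torch.randn(1, 3, 8, 8)
w = torch.randn(5, 3, 3, 3)
b = torch.randn(5)
y = torch.ops.aten.convolution.default(
    x, w, b, [2, 2], [1, 1], [1, 1], False, [0, 0], 1
)
print(y.shape)  # torch.Size([1, 5, 4, 4])
```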
```diff
@@ -340,7 +340,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
         is_per_channel: bool,
         has_relu: bool,
         has_bias: bool = True,
-        expected_conv_constant_args: Optional[Tuple[Any, ...]] = None,
+        expected_conv_literal_args: Optional[Tuple[Any, ...]] = None,
     ):
         """
         Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern
@@ -400,13 +400,13 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
         self.assertEqual(conv_node.target, torch.ops.aten.convolution.default)
         self.assertEqual(scale_factor_reshape_node.target, torch.ops.aten.view.default)
 
-        # Verify: conv constant args
-        if expected_conv_constant_args is not None:
+        # Verify: conv literal args
+        if expected_conv_literal_args is not None:
             assert (
-                len(expected_conv_constant_args) == 6
+                len(expected_conv_literal_args) == 6
             ), "wrong num conv args, bad test setup"
             for i in range(6):
-                self.assertEqual(conv_node.args[i + 3], expected_conv_constant_args[i])
+                self.assertEqual(conv_node.args[i + 3], expected_conv_literal_args[i])
 
         # Verify: conv input activation fake quantize
         conv_input_fq_node = conv_node.args[0]
```
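The helper reads the six literal args directly off the `aten.convolution.default` node, at positions 3 through 8. A hedged sketch (mine, not the test's code; `make_fx` stands in here for the dynamo export used by the real flow) of locating and inspecting such a node:

```python
import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, w, b):
    return F.conv2d(x, w, b, stride=(2, 2), padding=(1, 1))

# Tracing decomposes F.conv2d into aten.convolution.default
gm = make_fx(f)(
    torch.randn(1, 3, 8, 8), torch.randn(5, 3, 3, 3), torch.randn(5)
)
conv_node = next(
    n for n in gm.graph.nodes
    if n.target == torch.ops.aten.convolution.default
)
# args[3:] holds (stride, padding, dilation, transposed, output_padding,
# groups) -- exactly what gets compared against expected_conv_literal_args
print(conv_node.args[3:])
```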
```diff
@@ -1364,7 +1364,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             m, example_inputs, is_per_channel=True, has_relu=False
         )
 
-    def test_prepare_qat_conv_bn_fusion_constant_args(self):
+    def test_qat_conv_bn_fusion_literal_args(self):
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -1384,20 +1384,20 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             example_inputs,
             is_per_channel=False,
             has_relu=False,
-            expected_conv_constant_args=conv_args,
+            expected_conv_literal_args=conv_args,
         )
         self._verify_symmetric_qnnpack_qat_graph(
             M(),
             example_inputs,
             is_per_channel=True,
             has_relu=False,
-            expected_conv_constant_args=conv_args,
+            expected_conv_literal_args=conv_args,
         )
         self._verify_symmetric_qnnpack_qat_numerics(
-            M(), example_inputs, is_per_channel=False
+            M(), example_inputs, is_per_channel=False, verify_convert=True,
         )
         self._verify_symmetric_qnnpack_qat_numerics(
-            M(), example_inputs, is_per_channel=True
+            M(), example_inputs, is_per_channel=True, verify_convert=True,
         )
 
     def test_qat_conv_bn_fusion_no_conv_bias(self):
```
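The body of the test's toy module `M` is elided in the hunk above. A hedged guess at its shape (the literal values are my choice, not necessarily the test's): a conv + bn pair with non-default stride and padding, so the rewrite only preserves numerics if those literals are copied over:

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # non-default literal args; the real test picks its own values
        self.conv = torch.nn.Conv2d(3, 3, 3, stride=2, padding=4)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))

example_inputs = (torch.randn(1, 3, 5, 5),)
```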
```diff
@@ -374,6 +374,43 @@ def _get_conv_bn_getitem_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node]:
     assert getitem_node is not None
     return (conv_node, bn_node, getitem_node)
 
+def _filter_nodes_map(nodes_map: Dict[Node, Node]) -> Dict[Node, Node]:
+    """
+    Return a filtered `nodes_map` returned from the subgraph rewriter.
+    The filtered `nodes_map` will contain only nodes that are actually
+    matched in the pattern, excluding None or placeholder nodes.
+    """
+    new_nodes_map: Dict[Node, Node] = {}
+    for pattern_node, graph_node in nodes_map.items():
+        # bias can be None
+        if graph_node is None:
+            continue
+        # skip pattern placeholder nodes
+        if pattern_node.op == "placeholder":
+            continue
+        new_nodes_map[pattern_node] = graph_node
+    return new_nodes_map
+
+def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
+    """
+    Copy over literal args in conv, such as stride and padding, from the matched node
+    in the original graph to its replacement in the new graph.
+
+    This is needed due to the following limitation in the subgraph rewriter when used
+    with dynamo export: literal (non-tensor) args are not supported in the match and
+    replacement patterns. This is because dynamo export automatically inlines these
+    literal args, making them dead placeholder nodes. In the future, we should check
+    if dynamo export can optionally disable this inlining, or if subgraph rewriter
+    can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419.
+
+    Note: Unlike other tensor args like conv weights and biases, literal args are
+    preserved in the original nodes after replacement, so we can access them here.
+    """
+    assert original_node.target == torch.ops.aten.convolution.default
+    assert new_node.target == torch.ops.aten.convolution.default
+    # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
+    new_node.args = new_node.args[:3] + original_node.args[3:]
+
 def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     """
     Given a graph of decomposed aten ops, replace the (conv + bn) pattern with
```
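The core of `_copy_over_literal_conv_args` is the `args[:3] + args[3:]` splice. A standalone illustration of the idiom (a hand-built `torch.fx` graph with an assumed node layout, not output from the rewriter):

```python
import torch
from torch.fx import Graph

g = Graph()
x, w, b = g.placeholder("x"), g.placeholder("w"), g.placeholder("b")

# the matched node carries the user's literals (stride=2, padding=1, ...)
original = g.call_function(
    torch.ops.aten.convolution.default,
    (x, w, b, [2, 2], [1, 1], [1, 1], False, [0, 0], 1),
)
# the rewriter-produced node comes back with the pattern's defaults
replacement = g.call_function(
    torch.ops.aten.convolution.default,
    (x, w, b, [1, 1], [0, 0], [1, 1], False, [0, 0], 1),
)
# keep the tensor args, splice the six literals over from the matched node
replacement.args = replacement.args[:3] + original.args[3:]
assert replacement.args[3:] == ([2, 2], [1, 1], [1, 1], False, [0, 0], 1)
```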
```diff
@@ -430,8 +467,8 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     # (1) Copy over metadata from original subgraph. This ensures the stack traces
     #     and annotations are preserved in the new subgraph
     #
-    # (2) Copy over constant args for conv from the original subgraph
-    #     TODO: do this for constant args for batchnorm as well
+    # (2) Copy over literal args for conv from the original subgraph
+    #     TODO: do this for literal args for batchnorm as well
     #
     # In the future, we should try to push as much of this functionality into the
     # subgraph rewriter as possible, so we don't have to manually copy anything over.
@@ -442,34 +479,17 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
             _get_conv_bn_getitem_nodes(r.replacements)
 
         # Copy over metadata for all three nodes in [conv - bn - getitem]
-        # Also copy over constant args for conv
-        for match_pattern_node, original_node in r.nodes_map.items():
-            # bias can be None
-            if original_node is None:
-                continue
-            # We expect the subgraph rewriter to erase the non-literal args of the matched nodes.
-            # However, this is not done for placeholder nodes, since these nodes do not need to
-            # be replaced. Here we filter out these placeholder nodes since they do not need
-            # metadata copying. E.g. we want to filter out `getitem_placeholder` in this pattern:
-            #
-            #   getitem_placeholder -> conv -> bn -> getitem
-            #
-            if any(isinstance(a, Node) for a in original_node.args):
-                continue
+        # Also copy over literal args for conv
+        for match_pattern_node, original_node in _filter_nodes_map(r.nodes_map).items():
             if original_node.target == torch.ops.aten.convolution.default:
+                _copy_over_literal_conv_args(original_node, replacement_conv_node)
                 replacement_conv_node.meta = original_node.meta
-                # Note: Unlike other tensor args like conv weights and biases, literal args are
-                # preserved in the original nodes after replacement, so we can access them here
-                # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
-                replacement_conv_node.args = replacement_conv_node.args[:3] + original_node.args[3:]
-
                 # original annotation is referring to the node object in the graph
                 # after rewrite we'll need to update this mapping (input_qspec_map)
                 # update quantization_annotation
-                original_input_qspec_map = original_node.meta["quantization_annotation"].input_qspec_map
                 if "quantization_annotation" not in original_node.meta:
                     continue
 
                 input_qspec_map = {}
                 # get the list of configs, it should be ordered as input, weight, bias
                 # note: this is really hacky, we need a better solution, hopefully
```
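`_filter_nodes_map` subsumes both ad-hoc checks deleted above: the `None` value for a missing conv bias, and the `getitem_placeholder`-style pattern placeholders. A toy stand-in (plain objects, not real `torch.fx` nodes) showing the filtering contract:

```python
from types import SimpleNamespace

# Stand-ins mimicking entries of a rewriter nodes_map: one real match,
# one None value (missing conv bias), one placeholder-keyed entry.
conv_pat = SimpleNamespace(op="call_function", name="conv")
bias_pat = SimpleNamespace(op="call_function", name="bias")
getitem_pat = SimpleNamespace(op="placeholder", name="getitem_placeholder")

nodes_map = {conv_pat: "conv_in_graph", bias_pat: None, getitem_pat: "x"}

def filter_nodes_map(nodes_map):
    # same contract as _filter_nodes_map in the hunk above
    return {
        p: g for p, g in nodes_map.items()
        if g is not None and p.op != "placeholder"
    }

assert filter_nodes_map(nodes_map) == {conv_pat: "conv_in_graph"}
```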
```diff
@@ -482,7 +502,6 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
                 # bias
                 if len(replacement_conv_node.args) > 2 and len(all_configs) > 2:
                     input_qspec_map[replacement_conv_node.args[2]] = all_configs[2][1]
-
                 replacement_conv_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map
             if original_node.target == torch.ops.aten._native_batch_norm_legit.default:
                 replacement_bn_node.meta = original_node.meta
@@ -575,6 +594,11 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
         # fold bn weights into conv
         _fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
 
+        # Copy over literal args for conv
+        for _, original_node in _filter_nodes_map(r.nodes_map).items():
+            if original_node.target == torch.ops.aten.convolution.default:
+                _copy_over_literal_conv_args(original_node, conv_node)
+
     m.graph.eliminate_dead_code()
     m.recompile()
     return m
```
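To reproduce the Test Plan without the test runner's CLI, the case can also be loaded through `unittest` directly (assuming a PyTorch source checkout, run from its `test/` directory so `test_quantization` is importable):

```python
import unittest

# Equivalent to:
#   python test/test_quantization.py \
#       TestQuantizePT2E.test_qat_conv_bn_fusion_literal_args
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_quantization.TestQuantizePT2E.test_qat_conv_bn_fusion_literal_args"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```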