[quant][pt2e] Support conv bn fusion in convert step for QAT flow (#100442)

Summary:
This PR adds support for folding bn weights into conv in the convert step of the QAT flow. This is equivalent
to the QAT branch of `from_float` in the eager mode quantized conv module: https://github.com/pytorch/pytorch/blob/main/torch/ao/nn/quantized/modules/conv.py#L223
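
For reference, the folding itself is the usual conv - bn algebra (what `torch.nn.utils.fusion.fuse_conv_bn_weights` computes); a minimal sketch, with the helper name made up for illustration:

```python
import torch
import torch.nn.functional as F

def fold_bn_into_conv(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    # bn applies (y - running_mean) / sqrt(running_var + eps) * gamma + beta
    # per output channel, so the same effect can be baked into conv's
    # weight and bias ahead of time.
    bn_scale = bn_w / torch.sqrt(bn_rv + bn_eps)
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    fused_w = conv_w * bn_scale.reshape(-1, 1, 1, 1)  # (out_ch, in_ch, kh, kw)
    fused_b = (conv_b - bn_rm) * bn_scale + bn_b
    return fused_w, fused_b

conv = torch.nn.Conv2d(3, 3, 3)
bn = torch.nn.BatchNorm2d(3).eval()  # use running stats for the comparison
x = torch.randn(1, 3, 5, 5)
w, b = fold_bn_into_conv(conv.weight, conv.bias, bn.running_mean,
                         bn.running_var, bn.eps, bn.weight, bn.bias)
torch.testing.assert_close(bn(conv(x)), F.conv2d(x, w, b), rtol=1e-4, atol=1e-5)
```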

Items that need follow-up:
* There are some workarounds because `quantize_per_tensor` currently takes float/int scale and zero_point args, which dynamo does not support; this needs to be fixed once we change the quantized model representation and change these args to Tensors (the two call forms are sketched below).
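
To make the arg mismatch concrete (scale/zero_point values below are made up; the op names are the ones this diff touches), a rough sketch of the two call forms:

```python
import torch
# importing this registers the quantized_decomposed ops
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

x = torch.randn(1, 3, 5, 5)

# What convert currently emits: scale / zero_point as float / int literals
# (the overload packet, no specific overload).
q = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, 0.1, 0, -128, 127, torch.int8)

# What the match pattern traces into when scale / zero_point are fed as
# Tensors so dynamo can handle them as pattern inputs: the `.tensor` overload.
# The fold pass in this diff rewrites the pattern's targets back to the
# overload packet so the pattern still matches the converted model.
q_t = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
    x, torch.tensor([0.1]), torch.tensor([0], dtype=torch.int32), -128, 127, torch.int8)
```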

Test Plan: buck2 test @//mode/opt //caffe2/test:quantization_pt2e -- --exact 'caffe2/test:quantization_pt2e - test_convert_qat_conv_bn_fusion (quantization.pt2e.test_quantize_pt2e.TestQuantizePT2E)'
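
For context, the flow that test exercises looks roughly like this (a sketch only; the quantizer class and config helper names are assumptions, since the import list is elided in this diff):

```python
import torch
import torch._dynamo as torchdynamo
from torch.ao.quantization._quantize_pt2e import (
    convert_pt2e,
    prepare_qat_pt2e_quantizer,
)
# Assumed quantizer API of the time; exact names may differ.
from torch.ao.quantization._pt2e.quantizer import (
    QNNPackQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))

example_inputs = (torch.randn(1, 3, 5, 5),)
quantizer = QNNPackQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))

# export to an aten graph, insert fake quant per the quantizer, run some
# training/calibration steps, then convert; convert now also folds bn into conv.
model, guards = torchdynamo.export(M().train(), *example_inputs, aten_graph=True)
model = prepare_qat_pt2e_quantizer(model, quantizer)
model(*example_inputs)   # QAT / calibration step(s)
model = convert_pt2e(model)
```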

Reviewed By: andrewor14

Differential Revision: D45344281

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100442
Approved by: https://github.com/kimishpatel
Jerry Zhang 2023-05-09 19:43:51 +00:00 committed by PyTorch MergeBot
parent f92b3e1477
commit c3f3cb5b0f
6 changed files with 294 additions and 69 deletions


@@ -21,6 +21,7 @@ from torch.ao.quantization._pt2e.quantizer import (
)
from torch.ao.quantization._quantize_pt2e import (
convert_pt2e,
_convert_to_reference_decomposed_fx,
prepare_pt2e_quantizer,
prepare_qat_pt2e_quantizer,
)
@@ -32,9 +33,9 @@ from torch.ao.quantization.qconfig import (
default_symmetric_qnnpack_qat_qconfig,
)
from torch.ao.quantization.quantize_fx import (
convert_to_reference_fx,
prepare_fx,
prepare_qat_fx,
convert_to_reference_fx,
)
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
@@ -44,9 +45,6 @@ from torch.testing._internal.common_quantization import (
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.ao.quantization.quantize_fx import _convert_to_reference_decomposed_fx
@skipIfNoQNNPACK
class TestQuantizePT2E(QuantizationTestCase):
def test_simple_quantizer(self):
@@ -599,6 +597,7 @@ class TestQuantizePT2E(QuantizationTestCase):
model: torch.nn.Module,
example_inputs: Tuple[Any, ...],
is_per_channel: bool,
verify_convert: bool = False,
):
"""
Helper method to verify that the QAT numerics for PT2E quantization match those of
@@ -615,7 +614,7 @@
aten_graph=True,
)
model_pt2e = prepare_qat_pt2e_quantizer(model_pt2e, quantizer)
result_pt2e = model_pt2e(*example_inputs)
after_prepare_result_pt2e = model_pt2e(*example_inputs)
# FX
# Note: In order to match the PT2E numerics exactly, we need to feed the
@@ -632,11 +631,36 @@
qconfig_mapping = QConfigMapping().set_global(default_qconfig)
backend_config = get_qnnpack_backend_config()
model_fx = prepare_qat_fx(model_fx, qconfig_mapping, example_inputs, backend_config=backend_config)
result_fx = model_fx(*example_inputs)
after_prepare_result_fx = model_fx(*example_inputs)
# Verify that numerics match
self.assertEqual(result_pt2e, result_fx)
self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)
if verify_convert:
model_pt2e = convert_pt2e(model_pt2e)
quant_result_pt2e = model_pt2e(*example_inputs)
model_fx = _convert_to_reference_decomposed_fx(model_fx, backend_config=backend_config)
quant_result_fx = model_fx(*example_inputs)
self.assertEqual(quant_result_pt2e, quant_result_fx)
def test_convert_qat_conv_bn_numerics(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.bn = torch.nn.BatchNorm2d(3)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
example_inputs = (torch.randn(1, 3, 5, 5),)
self._verify_symmetric_qnnpack_qat_numerics(M(), example_inputs, is_per_channel=False, verify_convert=True)
# TODO: enable in a separate PR
# self._verify_symmetric_qnnpack_qat_numerics(M(), example_inputs, is_per_channel=True)
class TestQuantizePT2EModels(QuantizationTestCase):
@skip_if_no_torchvision


@@ -1284,6 +1284,7 @@ def get_fake_value(node, tx):
unimplemented("guard on data-dependent symbolic int/float")
elif isinstance(cause, torch.utils._sympy.value_ranges.ValueRangeError):
raise UserError(UserErrorType.CONSTRAIN_VIOLATION, e.args[0]) from e
# why don't we print the exception here?
raise TorchRuntimeError() from e


@@ -6,9 +6,10 @@ import torch
from torch.fx import GraphModule, Node
from torch.fx.subgraph_rewriter import _replace_pattern
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from .utils import _fold_bn_weights_into_conv_node
# Example inputs for both `_conv2d_bn_pattern` and `_fused_qat_conv2d_bn_pattern`
# Example inputs for both `_conv2d_bn_pattern` and `_qat_conv2d_bn_pattern`
_conv2d_bn_pattern_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
@@ -19,6 +20,23 @@ _conv2d_bn_pattern_example_inputs = (
torch.randn(1), # bn_running_var
)
# Example inputs for both `_quantized_qat_conv2d_bn_pattern` and `_folded_quantized_qat_conv2d_bn_pattern`
_quantized_conv2d_bn_pattern_example_inputs = (
torch.randn(1, 1, 3, 3).to(torch.int8), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
torch.tensor([1], dtype=torch.float), # input_scale
torch.tensor([0], dtype=torch.int), # input_zero_point
torch.tensor([1], dtype=torch.float), # weight_scale
torch.tensor([0], dtype=torch.int), # weight_zero_point
torch.tensor([1], dtype=torch.float), # output_scale
torch.tensor([0], dtype=torch.int), # output_zero_point
)
def _conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
@@ -32,7 +50,7 @@ def _conv2d_bn_pattern(
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True)
return x
def _fused_qat_conv2d_bn_pattern(
def _qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
@@ -62,6 +80,96 @@ def _fused_qat_conv2d_bn_pattern(
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
return x
def _quantized_qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
input_scale: torch.Tensor,
input_zero_point: torch.Tensor,
weight_scale: torch.Tensor,
weight_zero_point: torch.Tensor,
output_scale: torch.Tensor,
output_zero_point: torch.Tensor,
) -> torch.Tensor:
"""
Quantized version of qat conv bn pattern,
This is based on `nniqat.ConvBn2d._forward_approximate`.
used in qat convert, we first match this pattern and then replace it with
normal conv - bn pattern and then fold the weights of bn into conv
"""
# TODO: allow setting eps
bn_eps = 1e-5
weight_quant_min = -127
weight_quant_max = 127
input_quant_min = -128
input_quant_max = 127
output_quant_min = -128
output_quant_max = 127
running_std = torch.sqrt(bn_running_var + bn_eps)
scale_factor = bn_weight / running_std
weight_shape = [1] * len(conv_weight.shape)
weight_shape[0] = -1
bias_shape = [1] * len(conv_weight.shape)
bias_shape[1] = -1
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, input_quant_min, input_quant_max, torch.int8)
zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
scaled_weight = torch.ops.quantized_decomposed.quantize_per_tensor(
scaled_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
scaled_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
scaled_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
x = F.conv2d(x, scaled_weight, zero_bias)
x = x / scale_factor.reshape(bias_shape)
x = x + conv_bias.reshape(bias_shape)
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, output_scale, output_zero_point, output_quant_min, output_quant_max, torch.int8)
return x
def _folded_quantized_qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
bn_running_var: torch.Tensor,
input_scale: torch.Tensor,
input_zero_point: torch.Tensor,
weight_scale: torch.Tensor,
weight_zero_point: torch.Tensor,
output_scale: torch.Tensor,
output_zero_point: torch.Tensor,
) -> torch.Tensor:
""" Quantized QAT conv - bn pattern with bn weights being folded into conv
"""
# TODO: allow setting eps
bn_eps = 1e-5
weight_quant_min = -127
weight_quant_max = 127
input_quant_min = -128
input_quant_max = 127
output_quant_min = -128
output_quant_max = 127
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, input_quant_min, input_quant_max, torch.int8)
conv_weight = torch.ops.quantized_decomposed.quantize_per_tensor(
conv_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
conv_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
conv_weight, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
x = F.conv2d(x, conv_weight, conv_bias)
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, output_scale, output_zero_point, output_quant_min, output_quant_max, torch.int8)
return x
def _get_aten_graph_module(
pattern: Callable,
example_inputs: Tuple[Any, ...],
@@ -94,7 +202,7 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
m.recompile()
example_inputs = _conv2d_bn_pattern_example_inputs
match_pattern = _get_aten_graph_module(_conv2d_bn_pattern, example_inputs)
replacement_pattern = _get_aten_graph_module(_fused_qat_conv2d_bn_pattern, example_inputs)
replacement_pattern = _get_aten_graph_module(_qat_conv2d_bn_pattern, example_inputs)
# TODO: use the public replace_pattern API once it also returns replacement nodes
match_and_replacement = _replace_pattern(m, match_pattern, replacement_pattern)
m.recompile()
@@ -127,3 +235,60 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
if original_node.target == operator.getitem:
replacement_getitem_node.meta = original_node.meta
return m
def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
"""
Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv.
"""
m.graph.eliminate_dead_code()
m.recompile()
example_inputs = _quantized_conv2d_bn_pattern_example_inputs
match_pattern = _get_aten_graph_module(_quantized_qat_conv2d_bn_pattern, example_inputs)
# Workaround: the current convert step does not produce q/dq ops with a specific overload,
# so we remove the overload from the pattern here instead, since we do not want to break BC
for n in match_pattern.graph.nodes:
if n.op == "call_function" and n.target == torch.ops.quantized_decomposed.quantize_per_tensor.tensor:
n.target = torch.ops.quantized_decomposed.quantize_per_tensor
if n.op == "call_function" and n.target == torch.ops.quantized_decomposed.dequantize_per_tensor.tensor:
n.target = torch.ops.quantized_decomposed.dequantize_per_tensor
replacement_pattern = _get_aten_graph_module(_folded_quantized_qat_conv2d_bn_pattern, example_inputs)
# TODO: use the public replace_pattern API once it also returns replacement nodes
match_and_replacement = _replace_pattern(m, match_pattern, replacement_pattern, ignore_literals=True)
m.recompile()
for mr in match_and_replacement:
# Find replacement conv and bn nodes by climbing upwards from anchor node
assert len(mr.replacements) == 1, "expected only one replacement node"
# find conv, bn, weight, bias nodes in the graph
replacement_quantize_node = mr.replacements[0]
assert replacement_quantize_node.target == torch.ops.quantized_decomposed.quantize_per_tensor.tensor
n = replacement_quantize_node
conv_node = None
bn_node = None
while conv_node is None or bn_node is None:
if n.target == torch.ops.aten.convolution.default:
conv_node = n
if n.target == torch.ops.aten._native_batch_norm_legit.default:
bn_node = n
assert isinstance(n.args[0], Node)
n = n.args[0]
assert conv_node is not None and bn_node is not None
conv_weight_dq = conv_node.args[1]
assert conv_weight_dq.target == torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
conv_weight_q = conv_weight_dq.args[0]
assert conv_weight_q.target == torch.ops.quantized_decomposed.quantize_per_tensor.tensor
conv_weight = conv_weight_q.args[0]
assert conv_weight.op == "get_attr"
conv_bias = conv_node.args[2]
# fold bn weights into conv
_fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)
m.graph.eliminate_dead_code()
m.recompile()
return m


@@ -1,5 +1,8 @@ import torch
import torch
from torch.fx import GraphModule
from torch.fx import (
GraphModule,
Node,
)
from torch.nn.utils.fusion import fuse_conv_bn_weights
# TODO[jerryzh168]: move this to a more general util function
from torch.ao.quantization.fx.prepare import (
@@ -15,68 +18,94 @@ def _get_tensor_constant_from_node(node, m):
assert node.op == "get_attr"
return getattr(m, node.target)
def _get_all_arguments(orig_args, orig_kwargs, args_schema):
all_args = []
for i, schema in enumerate(args_schema):
if schema.name in orig_kwargs:
all_args.append(orig_kwargs[schema.name])
elif not schema.kwarg_only and i < len(orig_args):
all_args.append(orig_args[i])
else:
all_args.append(schema.default_value)
return all_args
def _fold_bn_weights_into_conv_node(
conv_node: Node,
conv_weight_node: Node,
conv_bias_node: Node,
bn_node: Node,
m: GraphModule
) -> None:
# conv weight
conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
# conv bias
conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
transpose = conv_node.args[6]
bn_args_schema = bn_node.target._schema.arguments # type: ignore[union-attr]
bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)
# bn weight
bn_w = _get_tensor_constant_from_node(bn_args[1], m)
# bn bias
bn_b = _get_tensor_constant_from_node(bn_args[2], m)
# bn running mean
bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
# bn running variance
bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
bn_eps = bn_args[6]
fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose)
# update the weight and bias for conv
conv_args = list(conv_node.args)
# calling data since the fused_weight and fused_bias are nn.Parameter
weight_attr_name = conv_weight_node.target
setattr(m, weight_attr_name, fused_weight) # type: ignore[arg-type]
if conv_bias_node is not None:
bias_attr_name = conv_bias_node.target
else:
bias_attr_name = weight_attr_name + "_bias"
with m.graph.inserting_before(conv_node):
get_bias_node = m.graph.get_attr(bias_attr_name)
# NOTE: here we assume the bias of conv is not quantized!
conv_args[2] = get_bias_node
setattr(m, bias_attr_name, fused_bias) # type: ignore[arg-type]
conv_node.args = tuple(conv_args)
# native_batch_norm has 3 outputs, we expect getitem calls on the output
# and we want to replace the uses of getitem 0 with the output of conv
#
# Before:
# conv -> bn - (first output) -> users1
# \ - (second output) -> users2
# \ - (third output) -> users3
# After:
# conv -> (first output) -> users1
# bn -
# \ - (second output) -> users2
# \ - (third output) -> users3
# if users2 and users3 are empty then bn will be removed through dead code elimination
for user in bn_node.users:
if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
continue
user.replace_all_uses_with(conv_node)
# fuse conv bn weights, inplace modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
for n in m.graph.nodes:
if n.op != "call_function" or n.target != torch.ops.aten._native_batch_norm_legit_no_training.default:
continue
bn_op = n
n = bn_op.args[0]
bn_node = n
n = bn_node.args[0]
if n.op != "call_function" or n.target != torch.ops.aten.convolution.default:
continue
conv_op = n
conv_node = n
conv_weight_node = conv_node.args[1]
conv_bias_node = conv_node.args[2]
_fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m)
# conv weight
conv_w = _get_tensor_constant_from_node(conv_op.args[1], m)
# conv bias
conv_b = _get_tensor_constant_from_node(conv_op.args[2], m)
transpose = conv_op.args[6]
# bn weight
bn_w = _get_tensor_constant_from_node(bn_op.args[1], m)
# bn bias
bn_b = _get_tensor_constant_from_node(bn_op.args[2], m)
# bn running mean
bn_rm = _get_tensor_constant_from_node(bn_op.args[3], m)
# bn running variance
bn_rv = _get_tensor_constant_from_node(bn_op.args[4], m)
bn_eps = bn_op.args[6]
fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose)
# update the weight and bias for conv
conv_args = list(conv_op.args)
# calling data since the fused_weight and fused_bias are nn.Parameter
weight_attr_name = conv_args[1].target
setattr(m, weight_attr_name, fused_weight)
if conv_args[2] is not None:
bias_attr_name = conv_args[2].target
else:
bias_attr_name = weight_attr_name + "_bias"
with m.graph.inserting_before(conv_op):
get_bias_node = m.graph.get_attr(bias_attr_name)
conv_args[2] = get_bias_node
setattr(m, bias_attr_name, fused_bias)
conv_op.args = tuple(conv_args)
# native_batch_norm has 3 outputs, we expect getitem calls on the output
# and we want to replace the uses of getitem 0 with the output of conv
#
# Before:
# conv -> bn - (first output) -> users1
# \ - (second output) -> users2
# \ - (third output) -> users3
# After:
# conv -> (first output) -> users1
# bn -
# \ - (second output) -> users2
# \ - (third output) -> users3
# if users2 and users3 are empty then bn will be removed through dead code elimination
for user in bn_op.users:
if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
continue
user.replace_all_uses_with(conv_op)
m.graph.eliminate_dead_code()
m.recompile()


@@ -2,7 +2,10 @@ from torch.fx import GraphModule
from ._pt2e.prepare import prepare
from ._pt2e._propagate_annotation import propagate_annotation
from ._pt2e.qat_utils import _fuse_conv_bn_qat
from ._pt2e.qat_utils import (
_fuse_conv_bn_qat,
_fold_conv_bn_qat,
)
from ._pt2e.utils import (
_get_node_name_to_scope,
_fuse_conv_bn_,
@@ -81,4 +84,6 @@ def prepare_qat_pt2e_quantizer(
def convert_pt2e(
model: GraphModule
):
return _convert_to_reference_decomposed_fx(model)
model = _convert_to_reference_decomposed_fx(model) # type: ignore[assignment]
model = _fold_conv_bn_qat(model)
return model


@@ -222,6 +222,7 @@ def _replace_pattern(
pattern: Union[Callable, GraphModule],
replacement: Union[Callable, GraphModule],
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined]
ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch
@@ -243,7 +244,7 @@
replacement_graph = symbolic_trace(replacement).graph
matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
remove_overlapping_matches=True)
remove_overlapping_matches=True, ignore_literals=ignore_literals)
_matches: List[InternalMatch] = matcher.match(original_graph)
# Filter out matches that don't match the filter