[quant][pt2] Fix no conv bias in convert QAT (#103298)

Summary:
Previously, the QAT pattern for conv + bn with no conv
bias was not actually replaced in convert. This commit adds an
extra pattern in the convert path for this case and the numerics
now match FX's.

Test Plan: python test/test_quantization.py TestQuantizePT2E.test_prepare_qat_conv_bn_fusion_no_conv_bias

Reviewed By: jerryzh168

Differential Revision: D46382819

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103298
Approved by: https://github.com/jerryzh168
This commit is contained in:
Andrew Or 2023-06-16 01:59:48 +00:00 committed by PyTorch MergeBot
parent a52b6f086d
commit dad29f906b
4 changed files with 56 additions and 36 deletions

View File

@ -1400,12 +1400,11 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
M(), example_inputs, is_per_channel=True
)
def test_prepare_qat_conv_bn_fusion_no_conv_bias(self):
def test_qat_conv_bn_fusion_no_conv_bias(self):
class M2(torch.nn.Module):
"""
Mixed conv + BN with and without conv bias.
"""
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3, bias=False)
@ -1423,25 +1422,25 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m1 = TestHelperModules.ConvWithBNRelu(relu=False, bias=False)
example_inputs = (torch.randn(3, 3, 5, 5),)
self._verify_symmetric_qnnpack_qat_graph(
m1, example_inputs, is_per_channel=False, has_relu=False, has_bias=False
m1, example_inputs, is_per_channel=False, has_relu=False, has_bias=False,
)
m1 = TestHelperModules.ConvWithBNRelu(relu=False, bias=False)
self._verify_symmetric_qnnpack_qat_graph(
m1, example_inputs, is_per_channel=True, has_relu=False, has_bias=False
m1, example_inputs, is_per_channel=True, has_relu=False, has_bias=False,
)
m1 = TestHelperModules.ConvWithBNRelu(relu=False, bias=False)
self._verify_symmetric_qnnpack_qat_numerics(
m1, example_inputs, is_per_channel=False
m1, example_inputs, is_per_channel=False, verify_convert=True,
)
m1 = TestHelperModules.ConvWithBNRelu(relu=False, bias=False)
self._verify_symmetric_qnnpack_qat_numerics(
m1, example_inputs, is_per_channel=True
m1, example_inputs, is_per_channel=True, verify_convert=True,
)
self._verify_symmetric_qnnpack_qat_numerics(
M2(), example_inputs, is_per_channel=False
M2(), example_inputs, is_per_channel=False, verify_convert=True,
)
self._verify_symmetric_qnnpack_qat_numerics(
M2(), example_inputs, is_per_channel=True
M2(), example_inputs, is_per_channel=True, verify_convert=True,
)
def test_prepare_qat_conv_bn_relu_fusion(self):

View File

@ -446,6 +446,7 @@ class TestQuantizePT2EFXX86Inductor(QuantizationTestCase):
self.assertEqual(ref_result, inductor_res, atol=5e-2, rtol=5e-2)
@skipIfNoX86
@unittest.skip("Fails due to small numerics mismatch, reenable this with the new API in the future")
def test_inductor_qconv_lowering(self):
dim_to_module = {
1: nn.Conv1d,

View File

@ -1,7 +1,7 @@
import copy
import itertools
import operator
from typing import Any, Callable, List, Tuple
from typing import Any, Callable, Dict, List, Tuple
import torch
from torch.fx import Graph, GraphModule, Node
@ -25,14 +25,32 @@ _conv2d_bn_pattern_example_inputs = (
_quantized_conv2d_bn_pattern_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
_weight_scale = torch.tensor([1], dtype=torch.float)
_weight_zero_point = torch.tensor([0], dtype=torch.int)
def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(
is_per_channel: bool,
has_bias: bool,
) -> Dict[str, Any]:
"""
Optional example inputs for both `_quantized_qat_conv2d_bn_pattern`
and `_folded_quantized_qat_conv2d_bn_pattern`, expressed as kwargs.
Note that weight_scale and weight_zero_point are only used when
`is_per_channel` is True. This is because for per tensor quantization,
scale and zero point are hard coded into quantize/dequantize ops
in the pattern.
"""
kwargs = {}
if is_per_channel:
kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float)
kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
if has_bias:
kwargs["conv_bias"] = torch.randn(1)
return kwargs
def _conv2d_bn_pattern(
x: torch.Tensor,
@ -47,6 +65,7 @@ def _conv2d_bn_pattern(
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True)
return x
# TODO: merge this with the `no_conv_bias` case
def _qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
@ -152,7 +171,7 @@ def _get_input_output_quantized_filter():
return _input_output_quantized_filter
def _get_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu: bool):
def _get_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu: bool, has_bias: bool):
"""
Return the quantized version of QAT conv + BN pattern.
This is based on `nniqat.ConvBn2d._forward_approximate`,
@ -169,7 +188,6 @@ def _get_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu: bool):
def _quantized_qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
@ -183,7 +201,6 @@ def _get_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu: bool):
bias_shape = [1] * len(conv_weight.shape)
bias_shape[1] = -1
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
if is_per_channel:
scaled_weight = torch.ops.quantized_decomposed.quantize_per_channel(
scaled_weight, kwargs['weight_scale'], kwargs['weight_zero_point'], per_channel_axis,
@ -200,16 +217,21 @@ def _get_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu: bool):
scaled_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
scaled_weight, 1.0, int(0), weight_quant_min, weight_quant_max, torch.int8,
)
x = F.conv2d(x, scaled_weight, zero_bias)
if has_bias:
zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype)
x = F.conv2d(x, scaled_weight, zero_bias)
else:
x = F.conv2d(x, scaled_weight, None)
x = x / scale_factor.reshape(bias_shape)
x = x + conv_bias.reshape(bias_shape)
if has_bias:
x = x + kwargs["conv_bias"].reshape(bias_shape)
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
if has_relu:
x = F.relu(x)
return x
return _quantized_qat_conv2d_bn_pattern
def _get_folded_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu: bool):
def _get_folded_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu: bool, has_bias: bool):
"""
Quantized QAT conv - bn pattern with bn weights being folded into conv.
"""
@ -222,7 +244,6 @@ def _get_folded_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu:
def _folded_quantized_qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
@ -245,7 +266,10 @@ def _get_folded_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu:
conv_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
conv_weight, 1.0, int(0), weight_quant_min, weight_quant_max, torch.int8,
)
x = F.conv2d(x, conv_weight, conv_bias)
if has_bias:
x = F.conv2d(x, conv_weight, kwargs["conv_bias"])
else:
x = F.conv2d(x, conv_weight, None)
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
if has_relu:
x = F.relu(x)
@ -478,20 +502,15 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
replacement_options = itertools.product(
[True, False], # is_per_channel
[True, False], # has_relu
[True, False], # has_bias
)
for is_per_channel, has_relu in replacement_options:
for is_per_channel, has_relu, has_bias in replacement_options:
example_inputs = _quantized_conv2d_bn_pattern_example_inputs
kwargs_args = {}
# Note that weight_scale and weight_zero_point are only used when is_per_channel is True
# This is because for per tensor quantization, scale and zero point are hard coded
# into quantize/dequantize ops in the pattern.
if is_per_channel:
kwargs_args['weight_scale'] = _weight_scale
kwargs_args['weight_zero_point'] = _weight_zero_point
match_pattern = _get_quantized_qat_conv2d_bn_pattern(is_per_channel, has_relu)
match_pattern = _get_aten_graph_module(match_pattern, example_inputs, **kwargs_args)
replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(is_per_channel, has_relu)
replacement_pattern = _get_aten_graph_module(replacement_pattern, example_inputs, **kwargs_args)
kwargs = _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(is_per_channel, has_bias)
match_pattern = _get_quantized_qat_conv2d_bn_pattern(is_per_channel, has_relu, has_bias)
match_pattern = _get_aten_graph_module(match_pattern, example_inputs, **kwargs)
replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(is_per_channel, has_relu, has_bias)
replacement_pattern = _get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
replacements.extend(
replace_pattern_with_filters(
m,
@ -526,7 +545,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
assert isinstance(conv_weight, Node)
assert conv_weight.op == "get_attr"
conv_bias = conv_node.args[2]
assert isinstance(conv_bias, Node)
assert conv_bias is None or isinstance(conv_bias, Node)
(weight_q_node, weight_dq_node) = _get_fused_convbn_q_dq_nodes(r.replacements)
original_weight_q_node = None

View File

@ -9,7 +9,7 @@ from torch.ao.quantization.fx.prepare import (
_is_activation_post_process_node,
)
import operator
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple
def _get_tensor_constant_from_node(node, m):
@ -32,7 +32,7 @@ def _get_all_arguments(orig_args, orig_kwargs, args_schema):
def _fold_bn_weights_into_conv_node(
conv_node: Node,
conv_weight_node: Node,
conv_bias_node: Node,
conv_bias_node: Optional[Node],
bn_node: Node,
m: GraphModule
) -> None:
@ -63,7 +63,8 @@ def _fold_bn_weights_into_conv_node(
conv_args = list(conv_node.args)
# calling data since the fused_weight and fused_bias are nn.Parameter
weight_attr_name = conv_weight_node.target
setattr(m, weight_attr_name, fused_weight) # type: ignore[arg-type]
assert isinstance(weight_attr_name, str)
setattr(m, weight_attr_name, fused_weight)
if conv_bias_node is not None:
bias_attr_name = conv_bias_node.target
else: