[ao] fixing quantized prelu workflow (#103455)

Summary: https://github.com/pytorch/pytorch/issues/100654 reported that PReLU was not
running its observers when the quantization flow was run. That bug is now fixed, and the
relevant PReLU tests now check for it. A corrected observer for PReLU was also added to
the default qconfig_mapping.
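
For context, here is a minimal eager-mode sketch of the path this change touches. It is not
part of the diff; the module, shapes, and random data are illustrative only. With the fix,
nnq.PReLU.from_float runs the weight observer on the float weight before computing
quantization parameters.

    import torch
    import torch.ao.nn.quantized as nnq
    from torch.ao.quantization import QConfig, default_observer

    # float PReLU with a qconfig attached, as in the updated test below
    f_prelu = torch.nn.PReLU(num_parameters=3)
    f_prelu.qconfig = QConfig(activation=default_observer, weight=default_observer)

    # calibrate the activation observer by hand (eager mode)
    f_prelu.activation_post_process = f_prelu.qconfig.activation()
    f_prelu.activation_post_process(f_prelu(torch.randn(2, 3, 8)))

    # with this fix, from_float() also runs the weight observer before calling
    # calculate_qparams(), so the quantized weight gets meaningful qparams
    q_prelu = nnq.PReLU.from_float(f_prelu)
    print(q_prelu.weight.dequantize())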

Test Plan: python test/test_quantization.py TestStaticQuantizedModule.test_prelu

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103455
Approved by: https://github.com/jerryzh168
HDCharles 2023-06-22 11:05:47 -07:00 committed by PyTorch MergeBot
parent 8a500f0be6
commit 8176cd8c0f
4 changed files with 48 additions and 22 deletions

@@ -1314,29 +1314,38 @@ class TestStaticQuantizedModule(QuantizationTestCase):
                                          offsets, set_qconfig, is_emb_bag=True, dtype=qdtype)
 
     def test_prelu(self):
-        x = torch.randn((4, 4, 4, 4), dtype=torch.float)
-        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
+        for num_parameters in range(1, 10):
+            x = torch.randn(4, num_parameters, 4)
+            qx = torch.quantize_per_tensor_dynamic(x, dtype=torch.quint8, reduce_range=False)
 
-        # num_parameters = 1
-        prelu_module = nnq.PReLU(output_scale=1.0, output_zero_point=0, num_parameters=1)
-        w = torch.randn(1, dtype=torch.float)
-        qw = torch.quantize_per_tensor(w, 1.0, 0, dtype=torch.quint8)
-        prelu_module.set_weight(qw)
-        qy = prelu_module(qx)
-        qy_ref = torch.prelu(qx, qw)
-        self.assertEqual(qy_ref, qy,
-                         msg="PReLU module API failed")
+            f_prelu = torch.nn.PReLU(num_parameters=num_parameters)
+            f_prelu.weight = torch.nn.Parameter(torch.randn(num_parameters).abs())
+            f_prelu.qconfig = torch.ao.quantization.QConfig(
+                activation=torch.ao.quantization.default_observer,
+                weight=torch.ao.quantization.default_observer,)
+            f_prelu.activation_post_process = f_prelu.qconfig.activation()
+            f_prelu.activation_post_process(f_prelu(x))
+            q_prelu = nnq.PReLU.from_float(f_prelu)
+            w_obs = f_prelu.qconfig.weight()
+            w_obs(f_prelu.weight)
+            w_scale, w_zp = w_obs.calculate_qparams()
+            q_prelu_weight = torch.quantize_per_tensor(
+                f_prelu.weight,
+                dtype=torch.quint8,
+                scale=w_scale,
+                zero_point=w_zp
+            ).dequantize()
 
-        # num_parameters = num_channels
-        prelu_module = nnq.PReLU(output_scale=1.0, output_zero_point=0, num_parameters=4)
-        w = torch.randn(4, dtype=torch.float)
-        qw = torch.quantize_per_tensor(w, 1.0, 0, dtype=torch.quint8)
-        prelu_module.set_weight(qw)
-        qy = prelu_module(qx)
-        qy_ref = torch.prelu(qx, qw)
-        self.assertEqual(qy_ref, qy,
-                         msg="PReLU module API failed")
+            # check that the weight makes sense
+            self.assertEqual(q_prelu.weight.dequantize(), q_prelu_weight)
+            f_prelu.weight = torch.nn.Parameter(q_prelu.weight.dequantize())
+            qy = q_prelu(qx)
+            qy_ref = torch.quantize_per_tensor(
+                f_prelu(qx.dequantize()), q_prelu.scale, q_prelu.zero_point, dtype=torch.quint8
+            )
+            # check that the output makes sense
+            self.assertEqual(qy, qy_ref, atol=.1, rtol=.1)
 
     def test_channel_shuffle(self):
         """Tests the correctness of the ChannelShuffle module.

@@ -1,5 +1,5 @@
 import torch
+from warnings import warn
 __all__ = [
     "ReLU6",
     "Hardswish",
@@ -271,6 +271,11 @@ class PReLU(torch.nn.Module):
         qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
         float_wt = mod.weight.float()
         observer = mod.qconfig.weight()
+        observer(float_wt)
+        if observer.dtype != torch.quint8:
+            warn(
+                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
+            )
         wt_scale, wt_zp = observer.calculate_qparams()
         qweight = torch.quantize_per_tensor(
             float_wt, float(wt_scale), int(wt_zp), torch.quint8)
@@ -282,6 +287,11 @@ class PReLU(torch.nn.Module):
         qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
         float_wt = mod.weight.float()
         observer = mod.qconfig.weight()
+        observer(float_wt)
+        if observer.dtype != torch.quint8:
+            warn(
+                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
+            )
         wt_scale, wt_zp = observer.calculate_qparams()
         qweight = torch.quantize_per_tensor(
             float_wt, float(wt_scale), int(wt_zp), torch.quint8)
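
The added observer(float_wt) call is the heart of the fix: an observer that never sees the
weight has nothing to base its quantization parameters on. A small illustration, assuming a
plain MinMaxObserver (calling calculate_qparams on an un-run observer should warn and fall
back to default qparams rather than reflect the data):

    import torch
    from torch.ao.quantization.observer import MinMaxObserver

    w = 5 * torch.randn(4)

    skipped = MinMaxObserver(dtype=torch.quint8)
    print(skipped.calculate_qparams())   # never saw w: warns, returns default qparams

    ran = MinMaxObserver(dtype=torch.quint8)
    ran(w)                               # the call that from_float() was missing
    print(ran.calculate_qparams())       # scale/zero_point now reflect w's actual range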

@@ -22,6 +22,7 @@ from torch.ao.quantization.fake_quantize import (
 from .observer import (
     _PartialWrapper,
     MinMaxObserver,
+    HistogramObserver,
     MovingAverageMinMaxObserver,
     NoopObserver,
@@ -56,6 +57,7 @@ __all__ = [
     "per_channel_dynamic_qconfig",
     "float_qparams_weight_only_qconfig",
     "float_qparams_weight_only_qconfig_4bit",
+    "default_quint8_weight_qconfig",
     "default_qat_qconfig",
     "default_dynamic_qat_qconfig",
     "default_weight_only_qconfig",
@@ -74,6 +76,7 @@ __all__ = [
     "get_default_qat_qconfig_dict",
     "QConfigAny",
     "qconfig_equals",
 ]
 
 class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
@@ -305,6 +308,8 @@ default_embedding_qat_qconfig = QConfig(activation=NoopObserver.with_args(dtype=
 default_embedding_qat_qconfig_4bit = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
                                              weight=default_embedding_fake_quant_4bit)
 
+default_quint8_weight_qconfig = QConfig(activation=HistogramObserver, weight=MinMaxObserver)
+
 def get_default_qat_qconfig(backend='x86', version=1):
     """
     Returns the default QAT qconfig for the specified backend.
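
For reference, a short sketch of what the new qconfig does when its observer constructors
are instantiated (the qconfig is re-declared locally here; both observer classes default to
torch.quint8, which is what the quantized PReLU module expects for its weight):

    import torch
    from torch.ao.quantization import QConfig
    from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver

    qconfig = QConfig(activation=HistogramObserver, weight=MinMaxObserver)
    act_obs = qconfig.activation()       # HistogramObserver instance for activations
    wt_obs = qconfig.weight()            # MinMaxObserver instance for the weight
    print(act_obs.dtype, wt_obs.dtype)   # both default to torch.quint8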

@@ -22,7 +22,8 @@ from .qconfig import (
     get_default_qconfig,
     get_default_qat_qconfig,
     QConfig,
-    QConfigAny
+    QConfigAny,
+    default_quint8_weight_qconfig
 )
@@ -92,6 +93,7 @@ def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QC
         .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
         .set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm) \
         .set_object_type(torch.nn.LayerNorm, qconfig_layernorm) \
+        .set_object_type(torch.nn.PReLU, default_quint8_weight_qconfig) \
 
     # Use special observers for ops with fixed qparams
     fixed_qparams_observer_to_qconfig: Dict[Any, QConfigAny] = {}
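
With torch.nn.PReLU registered in the default mapping, PReLU modules picked up by the FX
prepare step should now receive a qconfig (and hence observers) automatically. A quick
hedged check (object_type_qconfigs is assumed to be the mapping's dict of per-type qconfigs
in the current QConfigMapping API):

    import torch
    from torch.ao.quantization import get_default_qconfig_mapping

    qconfig_mapping = get_default_qconfig_mapping()
    prelu_qconfig = qconfig_mapping.object_type_qconfigs.get(torch.nn.PReLU)
    print(prelu_qconfig)   # expected: the default_quint8_weight_qconfig registered above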