[ao] fixing quantized prelu workflow (#103455)
Summary:
https://github.com/pytorch/pytorch/issues/100654 reported that PReLU was not running its observers when the quantization flow was run. This bug is now fixed, and the relevant PReLU tests check for it. A corrected weight observer for PReLU was also added to the default qconfig_mapping.

Test Plan:
python test/test_quantization.py TestStaticQuantizedModule.test_prelu

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103455
Approved by: https://github.com/jerryzh168
parent 8a500f0be6
commit 8176cd8c0f
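For context, a hedged, self-contained sketch of the eager-mode path this commit fixes: the float PReLU's activation and weight observers are run before conversion, mirroring what the updated test in the diff below exercises. The tensor shapes here are illustrative assumptions, not taken from the commit.

import torch
import torch.ao.nn.quantized as nnq

# Float PReLU with a qconfig attached, as eager-mode prepare() would set up.
f_prelu = torch.nn.PReLU(num_parameters=4)
f_prelu.qconfig = torch.ao.quantization.QConfig(
    activation=torch.ao.quantization.default_observer,
    weight=torch.ao.quantization.default_observer,
)

# Run the activation observer on some calibration data.
f_prelu.activation_post_process = f_prelu.qconfig.activation()
f_prelu.activation_post_process(f_prelu(torch.randn(2, 4, 8)))

# from_float now also runs the weight observer before computing the
# weight's quantization parameters (this was the missing step).
q_prelu = nnq.PReLU.from_float(f_prelu)

# The converted module consumes and produces quantized tensors.
x = torch.randn(2, 4, 8)
qx = torch.quantize_per_tensor_dynamic(x, dtype=torch.quint8, reduce_range=False)
print(q_prelu(qx))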
@@ -1314,29 +1314,38 @@ class TestStaticQuantizedModule(QuantizationTestCase):
                 offsets, set_qconfig, is_emb_bag=True, dtype=qdtype)

     def test_prelu(self):
-        x = torch.randn((4, 4, 4, 4), dtype=torch.float)
-        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
+        for num_parameters in range(1, 10):
+            x = torch.randn(4, num_parameters, 4)
+            qx = torch.quantize_per_tensor_dynamic(x, dtype=torch.quint8, reduce_range=False)

-        # num_parameters = 1
-        prelu_module = nnq.PReLU(output_scale=1.0, output_zero_point=0, num_parameters=1)
-        w = torch.randn(1, dtype=torch.float)
-        qw = torch.quantize_per_tensor(w, 1.0, 0, dtype=torch.quint8)
-        prelu_module.set_weight(qw)
-        qy = prelu_module(qx)
-        qy_ref = torch.prelu(qx, qw)
-
-        self.assertEqual(qy_ref, qy,
-                         msg="PReLU module API failed")
+            f_prelu = torch.nn.PReLU(num_parameters=num_parameters)
+            f_prelu.weight = torch.nn.Parameter(torch.randn(num_parameters).abs())
+            f_prelu.qconfig = torch.ao.quantization.QConfig(
+                activation=torch.ao.quantization.default_observer,
+                weight=torch.ao.quantization.default_observer,)
+            f_prelu.activation_post_process = f_prelu.qconfig.activation()
+            f_prelu.activation_post_process(f_prelu(x))
+            q_prelu = nnq.PReLU.from_float(f_prelu)
+            w_obs = f_prelu.qconfig.weight()
+            w_obs(f_prelu.weight)
+            w_scale, w_zp = w_obs.calculate_qparams()
+            q_prelu_weight = torch.quantize_per_tensor(
+                f_prelu.weight,
+                dtype=torch.quint8,
+                scale=w_scale,
+                zero_point=w_zp
+            ).dequantize()

-        # num_parameters = num_channels
-        prelu_module = nnq.PReLU(output_scale=1.0, output_zero_point=0, num_parameters=4)
-        w = torch.randn(4, dtype=torch.float)
-        qw = torch.quantize_per_tensor(w, 1.0, 0, dtype=torch.quint8)
-        prelu_module.set_weight(qw)
-        qy = prelu_module(qx)
-        qy_ref = torch.prelu(qx, qw)
-        self.assertEqual(qy_ref, qy,
-                         msg="PReLU module API failed")
+            # check that the weight makes sense
+            self.assertEqual(q_prelu.weight.dequantize(), q_prelu_weight)
+            f_prelu.weight = torch.nn.Parameter(q_prelu.weight.dequantize())
+            qy = q_prelu(qx)
+            qy_ref = torch.quantize_per_tensor(
+                f_prelu(qx.dequantize()), q_prelu.scale, q_prelu.zero_point, dtype=torch.quint8
+            )
+            # check that the output makes sense
+            self.assertEqual(qy, qy_ref, atol=.1, rtol=.1)

     def test_channel_shuffle(self):
         """Tests the correctness of the ChannelShuffle module.
@@ -1,5 +1,5 @@
 import torch
-
+from warnings import warn
 __all__ = [
     "ReLU6",
     "Hardswish",
@@ -271,6 +271,11 @@ class PReLU(torch.nn.Module):
         qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
         float_wt = mod.weight.float()
         observer = mod.qconfig.weight()
+        observer(float_wt)
+        if observer.dtype != torch.quint8:
+            warn(
+                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
+            )
         wt_scale, wt_zp = observer.calculate_qparams()
         qweight = torch.quantize_per_tensor(
             float_wt, float(wt_scale), int(wt_zp), torch.quint8)
@@ -282,6 +287,11 @@ class PReLU(torch.nn.Module):
         qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
         float_wt = mod.weight.float()
         observer = mod.qconfig.weight()
+        observer(float_wt)
+        if observer.dtype != torch.quint8:
+            warn(
+                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
+            )
         wt_scale, wt_zp = observer.calculate_qparams()
         qweight = torch.quantize_per_tensor(
             float_wt, float(wt_scale), int(wt_zp), torch.quint8)
@@ -22,6 +22,7 @@ from torch.ao.quantization.fake_quantize import (

 from .observer import (
     _PartialWrapper,
+    MinMaxObserver,
     HistogramObserver,
     MovingAverageMinMaxObserver,
     NoopObserver,
@@ -56,6 +57,7 @@ __all__ = [
     "per_channel_dynamic_qconfig",
     "float_qparams_weight_only_qconfig",
     "float_qparams_weight_only_qconfig_4bit",
+    "default_quint8_weight_qconfig",
     "default_qat_qconfig",
     "default_dynamic_qat_qconfig",
     "default_weight_only_qconfig",
@@ -74,6 +76,7 @@ __all__ = [
     "get_default_qat_qconfig_dict",
     "QConfigAny",
     "qconfig_equals",
 ]

 class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
@@ -305,6 +308,8 @@ default_embedding_qat_qconfig = QConfig(activation=NoopObserver.with_args(dtype=
 default_embedding_qat_qconfig_4bit = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
                                              weight=default_embedding_fake_quant_4bit)

+default_quint8_weight_qconfig = QConfig(activation=HistogramObserver, weight=MinMaxObserver)
+
 def get_default_qat_qconfig(backend='x86', version=1):
     """
     Returns the default QAT qconfig for the specified backend.
@@ -22,7 +22,8 @@ from .qconfig import (
     get_default_qconfig,
     get_default_qat_qconfig,
     QConfig,
-    QConfigAny
+    QConfigAny,
+    default_quint8_weight_qconfig
 )

@@ -92,6 +93,7 @@ def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QC
         .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
         .set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm) \
         .set_object_type(torch.nn.LayerNorm, qconfig_layernorm) \
+        .set_object_type(torch.nn.PReLU, default_quint8_weight_qconfig) \

     # Use special observers for ops with fixed qparams
     fixed_qparams_observer_to_qconfig: Dict[Any, QConfigAny] = {}
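As a usage note, here is a hedged sketch of how the new qconfig_mapping entry is picked up in FX graph mode quantization. The toy module, backend string, and input shape below are illustrative assumptions, not part of the change.

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class TinyPReLUModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1)
        self.prelu = torch.nn.PReLU(num_parameters=3)

    def forward(self, x):
        return self.prelu(self.conv(x))

model = TinyPReLUModel().eval()
example_inputs = (torch.randn(1, 3, 8, 8),)

# The default mapping now assigns default_quint8_weight_qconfig to torch.nn.PReLU,
# so its activation and weight observers are inserted during prepare_fx.
qconfig_mapping = get_default_qconfig_mapping("x86")
prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(*example_inputs)        # calibration runs the PReLU observers
quantized = convert_fx(prepared)
print(quantized)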