eager quant: convert mapping for fused QAT Linear-Bn1d (#72796)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72796

Adds the eager mode convert mapping for the fused QAT Linear-Bn1d module.
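
A minimal sketch of how this mapping is consumed: eager mode `convert` swaps each prepared module for the class registered against its type in the static quant module mappings. The import path below is an assumption based on the file this diff edits; the dict name itself appears in the diff.

```python
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
# Import path is an assumption; the dict name comes from the diff below.
from torch.ao.quantization.quantization_mappings import (
    DEFAULT_STATIC_QUANT_MODULE_MAPPINGS,
)

# With this change, a prepared fused Linear-Bn1d converts to a plain
# quantized Linear; the BN statistics are folded into the linear weights
# by nnq.Linear.from_float (see the last hunk of the diff).
assert DEFAULT_STATIC_QUANT_MODULE_MAPPINGS[nniqat.LinearBn1d] is nnq.Linear
```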

Test Plan:
```
python test/test_quantization.py TestQuantizeEagerQATNumerics.test_linear_bn_workflow
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D34213150

fbshipit-source-id: c08b5eb843dea673fd07c6b7b93dcd3ba03eaec2
(cherry picked from commit 722edfe676)
Vasiliy Kuznetsov authored 2022-02-18 05:08:02 -08:00, committed by PyTorch MergeBot
commit 1c0df26597 (parent e73eaffd3b)
4 changed files with 25 additions and 1 deletion

```diff
@@ -1006,6 +1006,22 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
         r2 = m(data)
         self.assertTrue(torch.allclose(r1, r2))
 
+    @override_qengines
+    def test_linear_bn_workflow(self):
+        qengine = torch.backends.quantized.engine
+        m = nn.Sequential(
+            QuantStub(),
+            nn.Linear(4, 4),
+            nn.BatchNorm1d(4),
+        )
+        data = torch.randn(4, 4)
+        m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
+        m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']])
+        mp = prepare_qat(m)
+        mp(data)
+        mq = convert(mp)
+        self.assertTrue(type(mq[1]) == nnq.Linear)
+        self.assertTrue(type(mq[2]) == nn.Identity)
+
 if __name__ == '__main__':
     raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
```
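
As a usage sketch (not part of the diff), the converted model from the test above can run inference directly; `mq` and `data` are the variables from the test:

```python
# mq[0] is now a Quantize stub, mq[1] an nnq.Linear with the BN stats
# folded into its weights, and mq[2] an nn.Identity placeholder.
with torch.no_grad():
    out = mq(data)  # executes quantized kernels; out is a quantized tensor
```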

```diff
@@ -80,6 +80,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
         nnqatd.Linear,
         nnqd.Linear,
         nniqat.LinearReLU,
+        nniqat.LinearBn1d,
         nn.modules.linear.NonDynamicallyQuantizableLinear,
     ]),
     # linear functionals
@@ -572,6 +573,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nniqat.ConvReLU2d,
         nniqat.ConvReLU3d,
         nniqat.LinearReLU,
+        nniqat.LinearBn1d,
         nniqd.LinearReLU,
     ])
```
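
The Numeric Suite entries above let comparison tooling treat the fused QAT module as related to the other linear variants. A quick check, sketched under the assumption of the function's import path (only its name is visible in the hunk):

```python
import torch.nn.intrinsic.qat as nniqat
# Import path is an assumption; the function name comes from the hunk above.
from torch.ao.ns.fx.mappings import get_base_name_to_sets_of_related_ops

base_name_to_sets = get_base_name_to_sets_of_related_ops()
# nniqat.LinearBn1d should now share a related-ops set with the other
# linear flavors, e.g. nniqat.LinearReLU from the same hunk.
assert any(nniqat.LinearBn1d in ops and nniqat.LinearReLU in ops
           for ops in base_name_to_sets.values())
```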

```diff
@@ -77,6 +77,7 @@ DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nniqat.ConvReLU2d: nniq.ConvReLU2d,
     nniqat.ConvReLU3d: nniq.ConvReLU3d,
     nniqat.LinearReLU: nniq.LinearReLU,
+    nniqat.LinearBn1d: nnq.Linear,
     # QAT modules:
     nnqat.Linear: nnq.Linear,
     nnqat.Conv2d: nnq.Conv2d,
```

```diff
@@ -3,7 +3,9 @@ import torch
 import torch.nn as nn
 import torch.nn.intrinsic as nni
+import torch.nn.intrinsic.qat as nniqat
 from torch.nn.quantized.modules.utils import _quantize_weight, hide_packed_params_repr, ReferenceableQuantizedModule
+from torch.nn.utils.fusion import fuse_linear_bn_weights
 from typing import Optional
 
 class LinearPackedParams(torch.nn.Module):
@@ -239,7 +241,10 @@ class Linear(ReferenceableQuantizedModule):
            utilities or provided by the user
         """
         if hasattr(mod, 'weight_fake_quant'):
-            # assert type(mod) == QATLinear, 'training mode nnq.Linear.from_float only works for nn.qat.Linear'
+            if type(mod) == nniqat.LinearBn1d:
+                mod.weight, mod.bias = fuse_linear_bn_weights(
+                    mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
+                    mod.bn.eps, mod.bn.weight, mod.bn.bias)
             weight_post_process = mod.weight_fake_quant
             activation_post_process = mod.activation_post_process
         else:
```
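
For reference, a self-contained sketch of the BN-folding algebra that `fuse_linear_bn_weights` applies in the hunk above; this is standard BN folding, and `fold_bn_into_linear` is a hypothetical stand-in, not the helper's actual internals:

```python
import torch

def fold_bn_into_linear(w, b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    # BatchNorm in eval mode is the affine map y -> (y - mean) * s + beta
    # with s = gamma / sqrt(var + eps). Composed with y = x @ w.T + b:
    #   w' = w * s.unsqueeze(1)
    #   b' = (b - mean) * s + beta
    scale = bn_w / torch.sqrt(bn_rv + bn_eps)
    return w * scale.unsqueeze(1), (b - bn_rm) * scale + bn_b

# Sanity check against running Linear then BatchNorm1d in eval mode.
lin = torch.nn.Linear(4, 4)
bn = torch.nn.BatchNorm1d(4).eval()
bn.running_mean.normal_()          # buffers, safe to mutate in place
bn.running_var.uniform_(0.5, 1.5)
x = torch.randn(8, 4)
with torch.no_grad():
    fw, fb = fold_bn_into_linear(lin.weight, lin.bias, bn.running_mean,
                                 bn.running_var, bn.eps, bn.weight, bn.bias)
    ref = bn(lin(x))
    fused = torch.nn.functional.linear(x, fw, fb)
assert torch.allclose(ref, fused, atol=1e-6)
```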