eager quant: convert mapping for fused QAT Linear-Bn1d (#72796)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72796
Adds the eager mode convert mapping for the fused QAT Linear-Bn1d module.
Test Plan:
```
python test/test_quantization.py TestQuantizeEagerQATNumerics.test_linear_bn_workflow
```
Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D34213150
fbshipit-source-id: c08b5eb843dea673fd07c6b7b93dcd3ba03eaec2
(cherry picked from commit 722edfe676)
This commit is contained in: parent e73eaffd3b, commit 1c0df26597
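For orientation, here is a minimal standalone sketch of the workflow this mapping completes, annotated with the module type at each stage. It mirrors the new test in the first hunk below; the 'fbgemm' engine choice and the intermediate nni/nniqat types (from the earlier fusion PRs in this stack) are assumptions, not part of this commit.

```
# Minimal sketch of the eager-mode QAT Linear-Bn1d workflow (not part of this
# commit). Assumes the 'fbgemm' engine; intermediate types per the fusion
# support added earlier in this stack.
import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
from torch.ao.quantization import (
    QuantStub, convert, fuse_modules_qat, get_default_qat_qconfig, prepare_qat)

m = nn.Sequential(QuantStub(), nn.Linear(4, 4), nn.BatchNorm1d(4))
m.qconfig = get_default_qat_qconfig('fbgemm')
m = fuse_modules_qat(m, [['1', '2']])
assert type(m[1]) == nni.LinearBn1d      # fused float module; BN slot -> Identity
mp = prepare_qat(m)
assert type(mp[1]) == nniqat.LinearBn1d  # fake-quant QAT module with a live BN
mp(torch.randn(4, 4))                    # training/calibration updates observers + BN
mq = convert(mp)                         # uses the mapping added in this commit
assert type(mq[1]) == nnq.Linear         # BN folded into the quantized Linear
assert type(mq[2]) == nn.Identity
```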
```
@@ -1006,6 +1006,22 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
         r2 = m(data)
         self.assertTrue(torch.allclose(r1, r2))
 
+    @override_qengines
+    def test_linear_bn_workflow(self):
+        qengine = torch.backends.quantized.engine
+        m = nn.Sequential(
+            QuantStub(),
+            nn.Linear(4, 4),
+            nn.BatchNorm1d(4),
+        )
+        data = torch.randn(4, 4)
+        m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
+        m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']])
+        mp = prepare_qat(m)
+        mp(data)
+        mq = convert(mp)
+        self.assertTrue(type(mq[1]) == nnq.Linear)
+        self.assertTrue(type(mq[2]) == nn.Identity)
+
 
 if __name__ == '__main__':
     raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
```
```
@@ -80,6 +80,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
         nnqatd.Linear,
         nnqd.Linear,
         nniqat.LinearReLU,
+        nniqat.LinearBn1d,
         nn.modules.linear.NonDynamicallyQuantizableLinear,
     ]),
     # linear functionals
```
```
@@ -572,6 +573,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nniqat.ConvReLU2d,
         nniqat.ConvReLU3d,
         nniqat.LinearReLU,
+        nniqat.LinearBn1d,
         nniqd.LinearReLU,
     ])
 
```
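The two numeric-suite hunks above register nniqat.LinearBn1d next to the other linear flavors so the FX numeric suite can relate it to them when matching subgraphs across models. A hedged illustration of querying the first table; the import path torch.ao.ns.fx.mappings is inferred from the function names in the hunk headers and may differ:

```
# Hedged illustration: confirm nniqat.LinearBn1d is grouped with the other
# linear variants in the numeric suite's related-ops table. The module path
# below is an assumption inferred from the hunk headers.
import torch.nn.intrinsic.qat as nniqat
from torch.ao.ns.fx.mappings import get_base_name_to_sets_of_related_ops

related = get_base_name_to_sets_of_related_ops()
linear_group = next(ops for ops in related.values() if nniqat.LinearBn1d in ops)
print(sorted(getattr(op, '__name__', str(op)) for op in linear_group))
```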
```
@@ -77,6 +77,7 @@ DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nniqat.ConvReLU2d: nniq.ConvReLU2d,
     nniqat.ConvReLU3d: nniq.ConvReLU3d,
     nniqat.LinearReLU: nniq.LinearReLU,
+    nniqat.LinearBn1d: nnq.Linear,
     # QAT modules:
     nnqat.Linear: nnq.Linear,
     nnqat.Conv2d: nnq.Conv2d,
```
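DEFAULT_STATIC_QUANT_MODULE_MAPPINGS is the lookup table eager-mode convert() walks when swapping modules, so the new entry is what sends a fused QAT LinearBn1d to a plain quantized Linear; by convert time the BN has been folded into the weights (see the last hunk below), so no dedicated fused quantized module is needed. A deliberately simplified sketch of the per-module lookup, not PyTorch's actual convert code:

```
# Deliberately simplified sketch of the exact-type lookup eager-mode convert()
# performs against DEFAULT_STATIC_QUANT_MODULE_MAPPINGS. The real code also
# handles observer removal, module reassignment, and recursion into children.
def convert_one_module(mod, mapping):
    target_cls = mapping.get(type(mod))    # e.g. type(mod) is nniqat.LinearBn1d
    if target_cls is not None:
        return target_cls.from_float(mod)  # here: nnq.Linear.from_float(mod)
    return mod
```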
```
@@ -3,7 +3,9 @@ import torch
 import torch.nn as nn
 import torch.nn.intrinsic as nni
+import torch.nn.intrinsic.qat as nniqat
 from torch.nn.quantized.modules.utils import _quantize_weight, hide_packed_params_repr, ReferenceableQuantizedModule
+from torch.nn.utils.fusion import fuse_linear_bn_weights
 from typing import Optional
 
 
 class LinearPackedParams(torch.nn.Module):
```
```
@@ -239,7 +241,10 @@ class Linear(ReferenceableQuantizedModule):
         utilities or provided by the user
         """
         if hasattr(mod, 'weight_fake_quant'):
-            # assert type(mod) == QATLinear, 'training mode nnq.Linear.from_float only works for nn.qat.Linear'
+            if type(mod) == nniqat.LinearBn1d:
+                mod.weight, mod.bias = fuse_linear_bn_weights(
+                    mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
+                    mod.bn.eps, mod.bn.weight, mod.bn.bias)
             weight_post_process = mod.weight_fake_quant
             activation_post_process = mod.activation_post_process
         else:
```
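The from_float change above is where the BN statistics actually disappear: for nniqat.LinearBn1d, the module's weight and bias are rewritten in place via fuse_linear_bn_weights before the usual weight-quantization path runs. A reference sketch of that folding arithmetic (the function name fold_bn_into_linear is hypothetical; the formula matches torch.nn.utils.fusion.fuse_linear_bn_weights):

```
# Reference sketch of the arithmetic behind fuse_linear_bn_weights
# (per-output-channel folding):
#   scale = bn_weight / sqrt(running_var + eps)
#   W'    = W * scale[:, None]
#   b'    = (b - running_mean) * scale + bn_bias
import torch

def fold_bn_into_linear(w, b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    scale = bn_w * torch.rsqrt(bn_rv + bn_eps)  # rsqrt = 1 / sqrt
    return w * scale.unsqueeze(-1), (b - bn_rm) * scale + bn_b
```

Since the folded weights make a plain quantized Linear reproduce Linear followed by eval-mode BatchNorm1d, the convert mapping can target nnq.Linear directly instead of introducing a fused quantized LinearBn1d module.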