eager quant: convert mapping for fused QAT Linear-Bn1d (#72796)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72796
Adds the eager mode convert mapping for the fused QAT Linear-Bn1d module.
Test Plan:
```
python test/test_quantization.py TestQuantizeEagerQATNumerics.test_linear_bn_workflow
```
Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D34213150
fbshipit-source-id: c08b5eb843dea673fd07c6b7b93dcd3ba03eaec2
(cherry picked from commit 722edfe676)
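For context, the end-to-end eager mode workflow this mapping enables looks roughly like the following. This is a minimal sketch mirroring the test added below, using only public torch.ao.quantization entry points, and assumes a build that includes this change.

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq

# Float model: Linear followed by BatchNorm1d, with a quant stub for eager mode.
m = nn.Sequential(
    tq.QuantStub(),
    nn.Linear(4, 4),
    nn.BatchNorm1d(4),
)
m.qconfig = tq.get_default_qat_qconfig(torch.backends.quantized.engine)

# Fuse Linear ('1') and BatchNorm1d ('2') into a single nniqat.LinearBn1d.
m = tq.fuse_modules_qat(m, [['1', '2']])
mp = tq.prepare_qat(m)
mp(torch.randn(4, 4))  # fake-quantized forward passes, i.e. QAT "training"

# With this change, convert() maps the fused nniqat.LinearBn1d to nnq.Linear,
# folding the BN statistics into the quantized linear weights.
mq = tq.convert(mp)
```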
This commit is contained in:
parent
e73eaffd3b
commit
1c0df26597
```diff
@@ -1006,6 +1006,22 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
         r2 = m(data)
         self.assertTrue(torch.allclose(r1, r2))
 
+    @override_qengines
+    def test_linear_bn_workflow(self):
+        qengine = torch.backends.quantized.engine
+        m = nn.Sequential(
+            QuantStub(),
+            nn.Linear(4, 4),
+            nn.BatchNorm1d(4),
+        )
+        data = torch.randn(4, 4)
+        m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
+        m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']])
+        mp = prepare_qat(m)
+        mp(data)
+        mq = convert(mp)
+        self.assertTrue(type(mq[1]) == nnq.Linear)
+        self.assertTrue(type(mq[2]) == nn.Identity)
 
 if __name__ == '__main__':
     raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
```
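Note the post-convert structure the two assertions verify: fuse_modules_qat replaces the BatchNorm1d slot ('2') with nn.Identity at fusion time, and convert then swaps the fused nniqat.LinearBn1d for a plain nnq.Linear, so the BN statistics end up folded into the quantized weights rather than surviving as a separate op.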
```diff
@@ -80,6 +80,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
             nnqatd.Linear,
             nnqd.Linear,
             nniqat.LinearReLU,
+            nniqat.LinearBn1d,
             nn.modules.linear.NonDynamicallyQuantizableLinear,
         ]),
         # linear functionals
```
```diff
@@ -572,6 +573,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nniqat.ConvReLU2d,
         nniqat.ConvReLU3d,
         nniqat.LinearReLU,
+        nniqat.LinearBn1d,
         nniqd.LinearReLU,
     ])
 
```
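These two additions register the fused module with the FX Numeric Suite mappings, grouping nniqat.LinearBn1d with the other linear variants. A quick sanity check could look like the sketch below; it relies on internal NS APIs, with the import path inferred from the hunk context, and may change between releases.

```python
import torch.nn.intrinsic.qat as nniqat
# Internal Numeric Suite mapping; import path assumed from the hunk context.
from torch.ao.ns.fx.mappings import get_base_name_to_sets_of_related_ops

related_ops = get_base_name_to_sets_of_related_ops()
# After this change, LinearBn1d should sit in one of the linear-related op sets.
assert any(nniqat.LinearBn1d in op_set for op_set in related_ops.values())
```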
```diff
@@ -77,6 +77,7 @@ DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nniqat.ConvReLU2d: nniq.ConvReLU2d,
     nniqat.ConvReLU3d: nniq.ConvReLU3d,
     nniqat.LinearReLU: nniq.LinearReLU,
+    nniqat.LinearBn1d: nnq.Linear,
     # QAT modules:
     nnqat.Linear: nnq.Linear,
     nnqat.Conv2d: nnq.Conv2d,
```
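DEFAULT_STATIC_QUANT_MODULE_MAPPINGS is the table eager mode convert() consults when swapping modules. Conceptually the swap works like the simplified sketch below; this is not the actual convert implementation (which lives in torch.ao.quantization and also handles observers, qconfig propagation, and more), just the core idea.

```python
from typing import Any, Callable, Dict

import torch.nn as nn

def swap_modules(model: nn.Module, mapping: Dict[Callable, Any]) -> None:
    # Recursively replace each module whose type appears in the mapping with
    # its quantized counterpart, built via the counterpart's from_float().
    # With this change, an nniqat.LinearBn1d child is replaced by
    # nnq.Linear.from_float(child), which performs the BN folding shown below.
    for name, child in model.named_children():
        swap_modules(child, mapping)
        if type(child) in mapping:
            setattr(model, name, mapping[type(child)].from_float(child))
```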
```diff
@@ -3,7 +3,9 @@ import torch
 
 import torch.nn as nn
 import torch.nn.intrinsic as nni
+import torch.nn.intrinsic.qat as nniqat
 from torch.nn.quantized.modules.utils import _quantize_weight, hide_packed_params_repr, ReferenceableQuantizedModule
+from torch.nn.utils.fusion import fuse_linear_bn_weights
 from typing import Optional
 
 class LinearPackedParams(torch.nn.Module):
```
```diff
@@ -239,7 +241,10 @@ class Linear(ReferenceableQuantizedModule):
             utilities or provided by the user
         """
         if hasattr(mod, 'weight_fake_quant'):
-            # assert type(mod) == QATLinear, 'training mode nnq.Linear.from_float only works for nn.qat.Linear'
+            if type(mod) == nniqat.LinearBn1d:
+                mod.weight, mod.bias = fuse_linear_bn_weights(
+                    mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
+                    mod.bn.eps, mod.bn.weight, mod.bn.bias)
             weight_post_process = mod.weight_fake_quant
             activation_post_process = mod.activation_post_process
         else:
```
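The folding applied in from_float is standard BatchNorm fusion. For reference, here is a self-contained sketch of the arithmetic that fuse_linear_bn_weights performs; the real helper lives in torch.nn.utils.fusion, as imported above.

```python
import torch

def fold_bn_into_linear(w, b, running_mean, running_var, eps, gamma, beta):
    # BN computes y = gamma * (x - mean) / sqrt(var + eps) + beta, so applying
    # it after a Linear is equivalent to a single Linear with rescaled weights:
    #   W' = W * (gamma / sqrt(var + eps)), applied per output channel
    #   b' = (b - mean) * gamma / sqrt(var + eps) + beta
    scale = gamma / torch.sqrt(running_var + eps)
    fused_w = w * scale.unsqueeze(1)   # scale each output row of W
    fused_b = (b - running_mean) * scale + beta
    return fused_w, fused_b
```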