quant: add QAT fused Linear-Bn1d [1/x]: prepared module (#72431)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72431

Adds support for a fused QAT observed module for `Linear` followed by
`BatchNorm1d`. This PR adds only the prepared module, with fake_quants
inserted in the right places.

A future PR will add support for `convert`, and tests for eager and FX
graph mode workflows.

Similar to conv-bn, we rescale the weight before applying the fake
quant, and undo the rescaling after the linear operation.
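
To make the rescaling concrete, here is a minimal sketch (illustrative
names, fake quant omitted) showing that scaling the weight rows by
`bn.weight / running_std` and dividing the output by the same factor
leaves the linear output unchanged:

```
import torch
import torch.nn.functional as F

linear = torch.nn.Linear(4, 4)
bn = torch.nn.BatchNorm1d(4)

# scale factor from BN's running statistics, one entry per output feature
running_std = torch.sqrt(bn.running_var + bn.eps)
scale_factor = bn.weight / running_std

x = torch.randn(2, 4)
scaled_w = linear.weight * scale_factor.reshape(-1, 1)  # scale each output row
out = F.linear(x, scaled_w)                             # linear without bias
out = out / scale_factor.reshape(1, -1) + linear.bias   # undo scaling, add bias

assert torch.allclose(out, linear(x), atol=1e-6)
```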

Test Plan:
```
python test/test_quantization.py TestQuantizeEagerQATNumerics.test_linear_bn_numerics
```

Imported from OSS

Reviewed By: jerryzh168, raghuramank10000

Differential Revision: D34044427

fbshipit-source-id: 47a519173939ca4824d2c6e6ea7a599764a8ed10
Vasiliy Kuznetsov 2022-02-18 05:08:02 -08:00 committed by Facebook GitHub Bot
parent 0ee7c1cc39
commit bfc75fe078
7 changed files with 197 additions and 1 deletion


@@ -1,5 +1,6 @@
# Owner(s): ["oncall: quantization"]
import copy
import math
import torch
import torch.nn as nn
@@ -10,6 +11,7 @@ from torch.nn.modules.utils import _pair
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
import torch.nn.intrinsic.qat as nniqat
import torch.nn.qat.dynamic as nnqatd
from torch.ao.quantization import (
prepare,
@@ -984,6 +986,27 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
            qat_op_optim.step()
            qat_ref_op_optim.step()

    @override_qengines
    def test_linear_bn_numerics(self):
        qengine = torch.backends.quantized.engine
        m_ref = nn.Sequential(
            nn.Linear(4, 4),
            nn.BatchNorm1d(4),
        )
        m_ref_copy = copy.deepcopy(m_ref)
        m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
        qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        m_ref_copy[0].qconfig = qconfig
        m = nniqat.LinearBn1d.from_float(m_ref_copy[0])

        # without fake_quants, fused QAT module should match fp32 module
        m.apply(torch.quantization.disable_fake_quant)
        data = torch.randn(4, 4)
        r1 = m_ref(data)
        r2 = m(data)
        self.assertTrue(torch.allclose(r1, r2))

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"


@@ -114,7 +114,12 @@ def fuse_linear_bn(is_qat, linear, bn):
    if is_qat:
        # TODO: remove the assert later
        assert linear.training, "qat is only supported when linear.training is True currently"
        raise Exception("Fusing Linear+BatchNorm not yet supported in training.")
        assert bn.num_features == linear.out_features,\
            "Output features of Linear must match num_features of BatchNorm1d"
        assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
        assert bn.track_running_stats,\
            "Only support fusing BatchNorm1d with tracking_running_stats set to True"
        return nni.LinearBn1d(linear, bn)
    else:
        return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)


@@ -99,6 +99,7 @@ DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.ConvReLU3d: nniqat.ConvReLU3d,
    nni.LinearReLU: nniqat.LinearReLU,
    nni.LinearBn1d: nniqat.LinearBn1d,
}

# Default map for swapping dynamic modules


@@ -11,6 +11,7 @@ from .fused import ConvReLU3d
from .fused import LinearReLU
from .fused import BNReLU2d
from .fused import BNReLU3d
from .fused import LinearBn1d

__all__ = [
@@ -27,4 +28,5 @@ __all__ = [
    'LinearReLU',
    'BNReLU2d',
    'BNReLU3d',
    'LinearBn1d',
]


@@ -113,3 +113,12 @@ class BNReLU3d(_FusedModule):
            'Incorrect types for input modules{}{}'.format(
                type(batch_norm), type(relu))
        super().__init__(batch_norm, relu)

class LinearBn1d(_FusedModule):
    r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, bn):
        assert type(linear) == Linear and type(bn) == BatchNorm1d, \
            'Incorrect types for input modules{}{}'.format(type(linear), type(bn))
        super().__init__(linear, bn)


@@ -1,4 +1,5 @@
from .linear_relu import LinearReLU
from .linear_fused import LinearBn1d
from .conv_fused import (
    ConvBn1d,
    ConvBn2d,
@@ -14,6 +15,7 @@ from .conv_fused import (
__all__ = [
    "LinearReLU",
    "LinearBn1d",
    "ConvReLU2d",
    "ConvReLU3d",
    "ConvBn1d",


@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter

class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
    r"""
    A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
    with FakeQuantize modules for weight, used in quantization aware training.

    We combined the interface of :class:`torch.nn.Linear` and
    :class:`torch.nn.BatchNorm1d`.

    Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn: whether to freeze the BatchNorm running statistics during training
        weight_fake_quant: fake quant module for weight
    """
    def __init__(self,
                 # Linear args
                 in_features, out_features, bias=True,
                 # BatchNorm1d args
                 # num_features: out_features
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)

    def reset_parameters(self):
        super(LinearBn1d, self).reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def forward(self, input):
        assert self.bn.running_var is not None

        # Scale the linear weights by BN's running statistics to reduce
        # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
        # for motivation.
        #
        # Instead of
        #
        #   x1 = F.linear(x0, fq(w), b)
        #   x2 = self.bn(x1)
        #
        # we have
        #
        #   # scale the weight by previous batch's running statistics
        #   scale_factor = bn.w / bn.running_std_from_prev_batch
        #   # do the linear transformation without bias
        #   x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
        #   # reverse the scaling and add original bias
        #   x1_orig = x1_scaled / scale_factor + b
        #   x2 = self.bn(x1_orig)
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias)
        else:
            zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
        linear_out = F.linear(input, scaled_weight, zero_bias)
        linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
        bn_out = self.bn(linear_out_orig)
        return bn_out

    def train(self, mode=True):
        """
        BatchNorm's training behavior is controlled by the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod`: a float module, either produced by torch.ao.quantization
        utilities or directly from user
        """
        assert type(mod) == nni.LinearBn1d, 'qat.' + cls.__name__ + \
            '.from_float only works for ' + nni.LinearBn1d.__name__
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        qconfig = mod.qconfig
        linear, bn = mod[0], mod[1]
        qat_linearbn = cls(linear.in_features, linear.out_features, linear.bias is not None,
                           bn.eps, bn.momentum,
                           False, qconfig)
        qat_linearbn.weight = linear.weight
        qat_linearbn.bias = linear.bias
        qat_linearbn.bn.weight = bn.weight
        qat_linearbn.bn.bias = bn.bias
        qat_linearbn.bn.running_mean = bn.running_mean
        qat_linearbn.bn.running_var = bn.running_var
        qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
        return qat_linearbn
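
For context, a minimal eager-mode usage sketch of the prepared module,
mirroring the new test above (it assumes the `fuse_modules_qat` and
`get_default_qat_qconfig` APIs used there; convert support is not part
of this PR):

```
import copy
import torch
import torch.nn as nn
import torch.nn.intrinsic.qat as nniqat

# float model: Linear followed by BatchNorm1d
m = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

# fuse for QAT and build the prepared (observed) module
m_fused = torch.ao.quantization.fuse_modules_qat(copy.deepcopy(m), [['0', '1']])
m_fused[0].qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
qat_mod = nniqat.LinearBn1d.from_float(m_fused[0])

# with fake quant disabled, numerics match the float model
qat_mod.apply(torch.quantization.disable_fake_quant)
x = torch.randn(4, 4)
assert torch.allclose(m(x), qat_mod(x))
```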