quant: add QAT fused Linear-Bn1d [1/x]: prepared module (#72431)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72431

Adds support for a fused QAT observed module for `Linear` followed by `BatchNorm1d`. In this PR, only support for the prepared module, with fake_quants in the right places, is added. A future PR will add support for `convert`, along with tests for the eager and FX graph mode workflows.

Similar to conv-bn, we rescale the weight before applying the fake quant, and undo the rescaling after the linear operation.

Test Plan:
```
python test/test_quantization.py TestQuantizeEagerQATNumerics.test_linear_bn
```

Imported from OSS

Reviewed By: jerryzh168, raghuramank10000

Differential Revision: D34044427

fbshipit-source-id: 47a519173939ca4824d2c6e6ea7a599764a8ed10
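The weight-rescaling approach described above, as a minimal sketch (the helper signature here is illustrative, not the actual module API; it assumes a 2D linear weight and a standard `BatchNorm1d` with running statistics):

```
import torch
import torch.nn.functional as F

def qat_linear_bn_forward(x, weight, bias, bn, fake_quant):
    # per-output-channel scale from the previous batch's running statistics
    running_std = torch.sqrt(bn.running_var + bn.eps)
    scale = (bn.weight / running_std).reshape(-1, 1)
    # linear with the fake-quantized, rescaled weight and no bias
    out = F.linear(x, fake_quant(weight * scale))
    # undo the rescaling, re-add the original bias, then apply BN as usual
    out = out / scale.reshape(1, -1)
    if bias is not None:
        out = out + bias
    return bn(out)
```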
This commit is contained in:
parent
0ee7c1cc39
commit
bfc75fe078
@@ -1,5 +1,6 @@
# Owner(s): ["oncall: quantization"]

import copy
import math
import torch
import torch.nn as nn
@@ -10,6 +11,7 @@ from torch.nn.modules.utils import _pair
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
import torch.nn.intrinsic.qat as nniqat
import torch.nn.qat.dynamic as nnqatd
from torch.ao.quantization import (
    prepare,
@@ -984,6 +986,27 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
            qat_op_optim.step()
            qat_ref_op_optim.step()

    @override_qengines
    def test_linear_bn_numerics(self):
        qengine = torch.backends.quantized.engine
        m_ref = nn.Sequential(
            nn.Linear(4, 4),
            nn.BatchNorm1d(4),
        )
        m_ref_copy = copy.deepcopy(m_ref)
        m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
        qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        m_ref_copy[0].qconfig = qconfig
        m = nniqat.LinearBn1d.from_float(m_ref_copy[0])

        # without fake_quants, fused QAT module should match fp32 module
        m.apply(torch.quantization.disable_fake_quant)
        data = torch.randn(4, 4)
        r1 = m_ref(data)
        r2 = m(data)
        self.assertTrue(torch.allclose(r1, r2))


if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
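For context, the prepared module is also reachable through the standard eager-mode QAT flow, since the QAT module mapping added later in this diff swaps `nni.LinearBn1d` for `nniqat.LinearBn1d` during prepare. A hedged sketch of that flow (the full workflow for this module is exercised in later PRs, so treat this as illustrative):

```
import copy
import torch
import torch.nn as nn
import torch.nn.intrinsic.qat as nniqat
import torch.ao.quantization as tq

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4)).train()
# fuse Linear + BatchNorm1d into an nni.LinearBn1d container
fused = tq.fuse_modules_qat(copy.deepcopy(model), [['0', '1']])
fused.qconfig = tq.get_default_qat_qconfig(torch.backends.quantized.engine)
# prepare_qat swaps nni.LinearBn1d for the QAT module via the default mapping
prepared = tq.prepare_qat(fused)
assert isinstance(prepared[0], nniqat.LinearBn1d)
```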
@@ -114,7 +114,12 @@ def fuse_linear_bn(is_qat, linear, bn):
     if is_qat:
-        raise Exception("Fusing Linear+BatchNorm not yet supported in training.")
+        # TODO: remove the assert later
+        assert linear.training, "qat is only supported when linear.training is True currently"
+        assert bn.num_features == linear.out_features,\
+            "Output features of Linear must match num_features of BatchNorm1d"
+        assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
+        assert bn.track_running_stats,\
+            "Only support fusing BatchNorm1d with tracking_running_stats set to True"
+        return nni.LinearBn1d(linear, bn)
     else:
         return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
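The eval branch above delegates to `nn.utils.fusion.fuse_linear_bn_eval`, which folds the BatchNorm into the linear parameters. A rough sketch of that folding math, under the affine and track_running_stats assumptions asserted above (the helper name is illustrative, not PyTorch API):

```
import torch

def fold_bn_into_linear_params(linear, bn):
    # y = gamma * (W x + b - mean) / sqrt(var + eps) + beta
    #   = (gamma / std) * W x + (gamma / std) * (b - mean) + beta
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / std                              # per-output-channel factor
    fused_weight = linear.weight * scale.unsqueeze(1)    # (out, in) * (out, 1)
    bias = linear.bias if linear.bias is not None else torch.zeros_like(bn.running_mean)
    fused_bias = (bias - bn.running_mean) * scale + bn.bias
    return fused_weight, fused_bias
```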
@@ -99,6 +99,7 @@ DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.ConvReLU3d: nniqat.ConvReLU3d,
    nni.LinearReLU: nniqat.LinearReLU,
    nni.LinearBn1d: nniqat.LinearBn1d,
}

# Default map for swapping dynamic modules
@@ -11,6 +11,7 @@ from .fused import ConvReLU3d
from .fused import LinearReLU
from .fused import BNReLU2d
from .fused import BNReLU3d
from .fused import LinearBn1d


__all__ = [
@@ -27,4 +28,5 @@ __all__ = [
    'LinearReLU',
    'BNReLU2d',
    'BNReLU3d',
    'LinearBn1d',
]
@@ -113,3 +113,12 @@ class BNReLU3d(_FusedModule):
            'Incorrect types for input modules{}{}'.format(
                type(batch_norm), type(relu))
        super().__init__(batch_norm, relu)


class LinearBn1d(_FusedModule):
    r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, bn):
        assert type(linear) == Linear and type(bn) == BatchNorm1d, \
            'Incorrect types for input modules{}{}'.format(type(linear), type(bn))
        super().__init__(linear, bn)
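The container simply runs Linear then BatchNorm1d, like an `nn.Sequential`. A small illustrative usage, assuming this change is applied:

```
import torch
import torch.nn as nn
import torch.nn.intrinsic as nni

pair = nni.LinearBn1d(nn.Linear(8, 16), nn.BatchNorm1d(16))
out = pair(torch.randn(2, 8))   # runs the Linear, then the BatchNorm1d
```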
@@ -1,4 +1,5 @@
from .linear_relu import LinearReLU
from .linear_fused import LinearBn1d
from .conv_fused import (
    ConvBn1d,
    ConvBn2d,
@@ -14,6 +15,7 @@ from .conv_fused import (

__all__ = [
    "LinearReLU",
    "LinearBn1d",
    "ConvReLU2d",
    "ConvReLU3d",
    "ConvBn1d",
torch/nn/intrinsic/qat/modules/linear_fused.py (new file, 154 lines)
@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter


class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
    r"""
    A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
    with FakeQuantize modules for weight, used in quantization aware training.

    We combine the interface of :class:`torch.nn.Linear` and
    :class:`torch.nn.BatchNorm1d`.

    Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn: whether the BatchNorm running statistics are frozen during training
        weight_fake_quant: fake quant module for weight

    """
    def __init__(self,
                 # Linear args
                 in_features, out_features, bias=True,
                 # BatchNorm1d args
                 # num_features: out_features
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)

    def reset_parameters(self):
        super(LinearBn1d, self).reset_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        self.bn.training = False
        return self

    def forward(self, input):
        assert self.bn.running_var is not None

        # Scale the linear weights by BN's running statistics to reduce
        # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
        # for motivation.
        #
        # Instead of
        #
        #   x1 = F.linear(x0, fq(w), b)
        #   x2 = self.bn(x1)
        #
        # we have
        #
        #   # scale the weight by previous batch's running statistics
        #   scale_factor = bn.w / bn.running_std_from_prev_batch
        #   # do the linear transformation without bias
        #   x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
        #   # reverse the scaling and add original bias
        #   x1_orig = x1_scaled / scale_factor + b
        #   x2 = self.bn(x1_orig)

        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(self.weight * scale_factor.reshape(weight_shape))
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias)
        else:
            zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
        linear_out = F.linear(input, scaled_weight, zero_bias)
        linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
        bn_out = self.bn(linear_out_orig)
        return bn_out

    def train(self, mode=True):
        """
        BatchNorm's training behavior is controlled by the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float module, either produced by torch.ao.quantization
        utilities or directly from user
        """
        assert type(mod) == nni.LinearBn1d, 'qat.' + cls.__name__ + \
            '.from_float only works for ' + nni.LinearBn1d.__name__
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid config'
        qconfig = mod.qconfig
        linear, bn = mod[0], mod[1]
        qat_linearbn = cls(linear.in_features, linear.out_features, linear.bias is not None,
                           bn.eps, bn.momentum,
                           False, qconfig)
        qat_linearbn.weight = linear.weight
        qat_linearbn.bias = linear.bias
        qat_linearbn.bn.weight = bn.weight
        qat_linearbn.bn.bias = bn.bias
        qat_linearbn.bn.running_mean = bn.running_mean
        qat_linearbn.bn.running_var = bn.running_var
        qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
        return qat_linearbn
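As a sanity check of the identity that `forward` relies on (scale the weight per output channel, then divide the linear output by the same factor), here is a standalone sketch; the tensors are arbitrary stand-ins, with `scale` playing the role of `bn.weight / running_std`:

```
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 4)
w = torch.randn(3, 4)
scale = torch.rand(3) + 0.5                    # stands in for bn.weight / running_std
scaled = F.linear(x, w * scale.reshape(-1, 1))
undone = scaled / scale.reshape(1, -1)
# without fake quant, rescaling then undoing is exact (up to float error)
assert torch.allclose(undone, F.linear(x, w), atol=1e-6)
```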