pytorch/test/quantization/eager/test_quantize_eager_qat.py

# Owner(s): ["oncall: quantization"]
import copy
import math
import torch
import torch.nn as nn
import torch.backends.mkldnn
from torch.nn import Conv2d, BatchNorm2d, ReLU, init
from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.nn.modules.utils import _pair
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.qat as nnqat
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat.dynamic as nnqatd
from torch.ao.quantization import (
prepare,
convert,
prepare_qat,
quantize_qat,
QuantStub,
DeQuantStub,
default_qconfig,
default_qat_qconfig,
default_embedding_qat_qconfig,
default_symmetric_qnnpack_qat_qconfig,
get_default_qat_qconfig,
FixedQParamsFakeQuantize,
FusedMovingAvgObsFakeQuantize,
get_embedding_qat_module_mappings,
get_embedding_static_quant_module_mappings,
NoopObserver,
)
from torch.ao.quantization.qconfig import qconfig_equals
from torch.testing._internal.common_quantization import (
DeFusedEmbeddingBagLinear,
QuantizationTestCase,
QuantStubModel,
ManualLinearQATModel,
ManualDropoutQATModel,
ManualLinearDynamicQATModel,
ManualConvLinearQATModel,
ManualConvLinearSymmQATModel,
ManualEmbeddingBagLinear,
TwoLayerLinearModel,
test_only_eval_fn,
test_only_train_fn,
)
from torch.testing._internal.common_quantized import (
override_quantized_engine,
supported_qengines,
override_qengines,
)
from torch.testing._internal.common_utils import skipIfNoXNNPACK
from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
from functools import reduce
class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
"""
Conv-BN fusion implemented with explicit folding. Useful for
verifying numerical equivalence with the non-folded version.
"""
def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups,
bias,
padding_mode,
# BatchNormNd args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
stride, padding, dilation, transposed,
output_padding, groups, False, padding_mode)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
self.eps = eps
self.momentum = momentum
self.freeze_bn = freeze_bn if self.training else True
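# in eval mode the BN statistics are always treated as frozen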
self.num_features = out_channels
self.gamma = nn.Parameter(torch.empty(out_channels))
self.beta = nn.Parameter(torch.empty(out_channels))
self.affine = True
self.track_running_stats = True
self.register_buffer('running_mean', torch.zeros(out_channels))
self.register_buffer('running_var', torch.ones(out_channels))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
self.activation_post_process = self.qconfig.activation()
self.weight_fake_quant = self.qconfig.weight()
if bias:
self.bias = nn.Parameter(torch.empty(out_channels))
else:
self.register_parameter('bias', None)
self.reset_bn_parameters()
def reset_running_stats(self):
self.running_mean.zero_()
self.running_var.fill_(1)
self.num_batches_tracked.zero_()
def reset_bn_parameters(self):
self.reset_running_stats()
init.uniform_(self.gamma)
init.zeros_(self.beta)
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def reset_parameters(self):
super().reset_parameters()
# A hack to avoid resetting on undefined parameters
if hasattr(self, 'gamma'):
self.reset_bn_parameters()
def update_bn_stats(self):
self.freeze_bn = False
return self
def freeze_bn_stats(self):
self.freeze_bn = True
return self
def _forward(self, input):
# exponential_average_factor is set to self.momentum
# (when it is available) only so that it gets updated
# in the ONNX graph when this node is exported to ONNX.
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and not self.freeze_bn and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
# we use running statistics from the previous batch, so this is an
# approximation of the approach mentioned in the whitepaper, but we only
# need to do one convolution in this case instead of two
running_std = torch.sqrt(self.running_var + self.eps)
scale_factor = self.gamma / running_std
scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
if self.bias is not None:
zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
else:
zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype)
conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), zero_bias)
if self.training and not self.freeze_bn:
# recovering original conv to get original batch_mean and batch_var
if self.bias is not None:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1])
else:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
n = float(conv_orig.numel() / conv_orig.size()[1])
unbiased_batch_var = batch_var * (n / (n - 1))
batch_rstd = torch.ones_like(batch_var, memory_format=torch.contiguous_format) / torch.sqrt(batch_var + self.eps)
conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + \
(self.beta - self.gamma * batch_rstd * batch_mean).reshape([1, -1, 1, 1])
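# i.e. conv = gamma * (conv_orig - batch_mean) / sqrt(batch_var + self.eps) + beta,
# which is BatchNorm applied with the current batch statistics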
self.running_mean = exponential_average_factor * batch_mean.detach() + \
(1 - exponential_average_factor) * self.running_mean
self.running_var = exponential_average_factor * unbiased_batch_var.detach() + \
(1 - exponential_average_factor) * self.running_var
else:
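# BN is frozen (or we are in eval): the conv output already uses the folded
# weight, so only the affine shift computed from the running statistics
# (and the conv bias, if present) needs to be added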
if self.bias is None:
conv = conv + (self.beta - self.gamma * self.running_mean /
running_std).reshape([1, -1, 1, 1])
else:
conv = conv + (self.gamma * (self.bias - self.running_mean) / running_std + self.beta).reshape([1, -1, 1, 1])
return conv
def extra_repr(self):
# TODO(jerryzh): extend
return super().extra_repr()
def forward(self, input):
return self.activation_post_process(self._forward(input))
@classmethod
def from_float(cls, mod, qconfig=None):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
if not qconfig:
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
qconfig = mod.qconfig
conv, bn = mod[0], mod[1]
qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
conv.stride, conv.padding, conv.dilation,
conv.groups, conv.bias is not None,
conv.padding_mode,
bn.eps, bn.momentum,
False,
qconfig)
qat_convbn.weight = conv.weight
qat_convbn.bias = conv.bias
qat_convbn.gamma = bn.weight
qat_convbn.beta = bn.bias
qat_convbn.running_mean = bn.running_mean
qat_convbn.running_var = bn.running_var
qat_convbn.num_batches_tracked = bn.num_batches_tracked
return qat_convbn
class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d):
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvBn2d
def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm2d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
_ReferenceConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, False, _pair(0), groups, bias, padding_mode,
eps, momentum, freeze_bn, qconfig)
class TestQuantizeEagerQAT(QuantizationTestCase):
def setUp(self):
super().setUp()
self.embed_linear_data_train = [[torch.randint(0, 10, (12, 12), dtype=torch.long),
torch.randn((12, 1), dtype=torch.float)]
for _ in range(2)]
self.embed_data = [[torch.randint(0, 10, (12, 1))]]
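# The eager mode QAT flow exercised by these tests is: attach a qconfig,
# call prepare_qat() to swap in fake-quant / observer modules, run a few
# training iterations, then convert() to get the quantized model.
# quantize_qat() performs the same steps in a single call.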
def test_manual(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ManualLinearQATModel(qengine)
model = prepare_qat(model)
self.checkObservers(model)
test_only_train_fn(model, self.train_data)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.fc1), nnq.Linear)
self.assertEqual(type(model.fc2), nnq.Linear)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
self.checkNoQconfig(model)
checkQuantized(model)
model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn,
[self.train_data])
checkQuantized(model)
def test_dropout(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ManualDropoutQATModel(qengine)
model = prepare_qat(model)
self.checkObservers(model)
test_only_train_fn(model, self.train_data)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.fc1), nnq.Linear)
self.assertEqual(type(model.dropout), nnq.Dropout)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
self.checkNoQconfig(model)
checkQuantized(model)
model = quantize_qat(ManualDropoutQATModel(qengine), test_only_train_fn,
[self.train_data])
checkQuantized(model)
def test_eval_only_fake_quant(self):
r"""Using FakeQuant in evaluation only mode,
this is useful for estimating accuracy loss when we quantize the
network
"""
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ManualLinearQATModel(qengine)
model = prepare_qat(model)
self.checkObservers(model)
model.eval()
test_only_eval_fn(model, self.calib_data)
def test_conv_linear(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ManualConvLinearQATModel()
model = prepare_qat(model)
self.checkObservers(model)
test_only_train_fn(model, self.img_data_2d_train)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.conv), nnq.Conv2d)
self.assertEqual(type(model.fc1), nnq.Linear)
self.assertEqual(type(model.fc2), nnq.Linear)
test_only_eval_fn(model, self.img_data_2d)
self.checkScriptable(model, self.img_data_2d)
self.checkNoQconfig(model)
checkQuantized(model)
model = ManualConvLinearQATModel()
model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
checkQuantized(model)
@skipIfNoXNNPACK
def test_conv_linear_symm(self):
r"""Same as test_conv_linear but with Symmetric quantization.
Supported only with qengine=qnnpack, which uses symmetric
kernels from xnnpack library."""
for qengine in supported_qengines:
if qengine != 'qnnpack':
continue
with override_quantized_engine(qengine):
model = ManualConvLinearSymmQATModel()
model = prepare_qat(model)
self.checkObservers(model)
test_only_train_fn(model, self.img_data_2d_train)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.conv), nnq.Conv2d)
self.assertEqual(type(model.fc1), nnq.Linear)
self.assertEqual(type(model.fc2), nnq.Linear)
test_only_eval_fn(model, self.img_data_2d)
self.checkScriptable(model, self.img_data_2d)
self.checkNoQconfig(model)
checkQuantized(model)
model = ManualConvLinearSymmQATModel()
model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
checkQuantized(model)
def test_dynamic_qat_linear(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
# Dynamic QAT without memoryless observers should fail
with self.assertRaisesRegex(ValueError,
"Dynamic QAT requires a memoryless observer." +
"This means a MovingAverage observer with averaging constant equal to 1"
):
model = ManualLinearDynamicQATModel(default_qat_qconfig)
model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})
model = ManualLinearDynamicQATModel()
model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})
self.assertEqual(type(model.fc1), nnqatd.Linear)
self.assertEqual(type(model.fc2), nnqatd.Linear)
self.checkObservers(model)
test_only_train_fn(model, self.train_data)
model = convert(model, mapping={nnqatd.Linear: nnqd.Linear})
self.assertEqual(type(model.fc1), nnqd.Linear)
self.assertEqual(type(model.fc2), nnqd.Linear)
test_only_eval_fn(model, self.calib_data)
self.checkScriptable(model, self.calib_data)
self.checkNoQconfig(model)
def test_defused_embedding_bag_linear(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = DeFusedEmbeddingBagLinear().train()
model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
self.checkObservers(model)
test_only_train_fn(model, self.embed_linear_data_train)
# make sure activation_post_process is inserted after Linear.
self.assertEqual(type(model.linear.activation_post_process), FusedMovingAvgObsFakeQuantize)
# make sure that Embedding has a noop for activation.
self.assertEqual(type(model.emb.activation_post_process), NoopObserver)
# make sure that FakeQuant zero_points are correct dtype
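# (embedding weights use the per_channel_affine_float_qparams qscheme, which
# carries float zero_points; linear weights use integer zero_points)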
self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32)
model = convert(model, mapping=get_embedding_static_quant_module_mappings())
def checkQuantized(model):
# make sure Embedding is now a QuantizedEmbedding
self.assertEqual(type(model.emb), nn.quantized.Embedding)
# make sure Linear is now a QuantizedLinear
self.assertEqual(type(model.linear), nn.quantized.Linear)
test_only_eval_fn(model, self.embed_data)
self.checkScriptable(model, self.embed_data)
self.checkNoQconfig(model)
checkQuantized(model)
def test_embedding_bag_linear(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ManualEmbeddingBagLinear().train()
model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
self.checkObservers(model)
test_only_train_fn(model, self.embed_linear_data_train)
# make sure no activation_post_process is inserted for the EmbeddingBag
self.assertFalse(hasattr(model, "activation_post_process"))
# make sure that FakeQuant zero_points are correct dtype
self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32)
model = convert(model, mapping=get_embedding_static_quant_module_mappings())
def checkQuantized(model):
# Make sure EmbeddingBag is now a quantized EmbeddingBag.
self.assertEqual(type(model.emb), nn.quantized.EmbeddingBag)
# Also check that Linear has been quantized.
self.assertEqual(type(model.linear), nnq.Linear)
test_only_eval_fn(model, self.embed_data)
self.checkScriptable(model, self.embed_data)
self.checkNoQconfig(model)
checkQuantized(model)
model = ManualEmbeddingBagLinear()
def test_train_save_load_eval(self):
r"""Test QAT flow of creating a model, doing QAT and saving the quantized state_dict
During eval, we first call prepare_qat and convert on the model and then load the state_dict
and compare results against original model
"""
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = TwoLayerLinearModel()
model = torch.ao.quantization.QuantWrapper(model)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
model = prepare_qat(model)
fq_state_dict = model.state_dict()
test_only_train_fn(model, self.train_data)
model = convert(model)
quant_state_dict = model.state_dict()
x = torch.rand(2, 5, dtype=torch.float)
ref = model(x)
# Create model again for eval. Check result using quantized state_dict
model = TwoLayerLinearModel()
model = torch.ao.quantization.QuantWrapper(model)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
torch.ao.quantization.prepare_qat(model, inplace=True)
new_state_dict = model.state_dict()
# Check to make sure the model after prepare_qat has the same state_dict as original.
self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys()))
torch.ao.quantization.convert(model, inplace=True)
model.eval()
model.load_state_dict(quant_state_dict)
out = model(x)
self.assertEqual(ref, out)
# Check model created using prepare has same state dict as quantized state_dict
model = TwoLayerLinearModel()
model.eval()
model = torch.ao.quantization.QuantWrapper(model)
model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
torch.ao.quantization.prepare(model, inplace=True)
torch.ao.quantization.convert(model, inplace=True)
self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys()))
model.eval()
model.load_state_dict(quant_state_dict)
out = model(x)
self.assertEqual(ref, out)
@override_qengines
def test_forward_hooks_preserved(self):
r"""Test QAT on preserving pre forward and post forward hooks of original model
"""
qengine = torch.backends.quantized.engine
model = QuantStubModel()
counter = {
'pre_forwards': 0,
'forwards': 0,
}
def fw_pre_hook(h_module, input):
counter['pre_forwards'] += 1
def fw_hook(h_module, input, output):
counter['forwards'] += 1
model.fc.register_forward_pre_hook(fw_pre_hook)
model.fc.register_forward_hook(fw_hook)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
model = prepare_qat(model)
def checkHooksIsPresent(model, before_convert=True):
forward_hooks = 1
if before_convert:
self.assertEqual(len(model.quant._forward_hooks.values()), 1,
"Quantization observer hook has disappeared")
forward_hooks = 2
self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
"Extra pre forward hooks have appeared on a layer")
self.assertEqual(len(model.fc._forward_hooks.values()), forward_hooks,
"Extra post forward hooks have appeared on a layer")
checkHooksIsPresent(model, True)
x = torch.rand(2, 5, dtype=torch.float)
model(x)
torch.ao.quantization.convert(model, inplace=True)
checkHooksIsPresent(model, False)
def test_add_scalar_uses_input_qparams(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.quant = torch.ao.quantization.QuantStub()
self.ff = torch.ao.nn.quantized.FloatFunctional()
def forward(self, x):
x = self.quant(x)
x = self.ff.add_scalar(x, 1.0)
return x
m = M()
m.qconfig = torch.ao.quantization.default_qconfig
mp = torch.ao.quantization.prepare_qat(m)
mp(torch.randn(4, 4))
mq = torch.ao.quantization.convert(mp)
res = mq(torch.randn(4, 4))
eps = 1e-5
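# add_scalar reuses the input's quantization parameters, so the output
# scale should match the QuantStub's scale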
self.assertTrue(torch.abs(mq.quant.scale - res.q_scale()) < eps)
def test_mul_scalar_uses_input_qparams(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.quant = torch.ao.quantization.QuantStub()
self.ff = torch.ao.nn.quantized.FloatFunctional()
def forward(self, x):
x = self.quant(x)
x = self.ff.mul_scalar(x, 2.0)
return x
m = M()
m.qconfig = torch.ao.quantization.default_qconfig
mp = torch.ao.quantization.prepare_qat(m)
mp(torch.randn(4, 4))
mq = torch.ao.quantization.convert(mp)
res = mq(torch.randn(4, 4))
eps = 1e-5
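# mul_scalar scales the output quantization parameters by the scalar, so
# the output scale should be the input scale multiplied by 2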
self.assertTrue(torch.abs(mq.quant.scale * 2 - res.q_scale()) < eps)
@override_qengines
def test_qat_embedding_bag_errors(self):
default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)
# Test constructor parameters checks here.
with self.assertRaisesRegex(AssertionError,
"qconfig must be provided for QAT module"):
nnqat.EmbeddingBag(10, 5, qconfig=None)
with self.assertRaisesRegex(AssertionError,
"Embedding Bag weights requires a qscheme of " +
"torch.per_channel_affine_float_qparams"):
nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig)
# Test from_float checks here.
embed = nn.Embedding(10, 5)
with self.assertRaisesRegex(AssertionError,
"qat.EmbeddingBag.from_float only works for EmbeddingBag"):
nnqat.EmbeddingBag.from_float(embed)
embed_bag = nn.EmbeddingBag(10, 5)
with self.assertRaisesRegex(AssertionError,
"Input float module must have qconfig defined"):
nnqat.EmbeddingBag.from_float(embed_bag)
embed_bag.qconfig = None
with self.assertRaisesRegex(AssertionError,
"Input float module must have a valid qconfig"):
nnqat.EmbeddingBag.from_float(embed_bag)
embed_bag.qconfig = default_qat_qconfig
with self.assertRaisesRegex(AssertionError,
"Embedding Bag weights requires a qscheme of " +
"torch.per_channel_affine_float_qparams"):
nnqat.EmbeddingBag.from_float(embed_bag)
def test_embedding_qat_qconfig_equal(self):
# Embedding QAT uses a NoopObserver class for activations and a FakeQuant
# for weights; make sure that qconfig comparison works properly for a mix
# of partial functions and classes in the qconfig.
model = ManualEmbeddingBagLinear().train()
model = prepare_qat(model)
self.assertTrue(qconfig_equals(model.emb.qconfig,
default_embedding_qat_qconfig))
class TestQuantizeEagerQATNumerics(QuantizationTestCase):
def _test_activation_convert_numerics_impl(self, Act, data):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.act = Act()
self.quant = QuantStub()
self.dequant = DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.act(x)
x = self.dequant(x)
return x
m = M().train()
m.qconfig = default_qat_qconfig
m = prepare_qat(m)
before_convert = m(data)
m = convert(m)
after_convert = m(data)
self.assertEqual(before_convert, after_convert)
def test_fixed_qparam_ops(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.sigmoid = torch.nn.Sigmoid()
self.hardsigmoid = torch.nn.Hardsigmoid()
self.tanh = torch.nn.Tanh()
self.quant = QuantStub()
self.dequant = DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.sigmoid(x)
x = self.hardsigmoid(x)
x = self.tanh(x)
x = self.dequant(x)
return x
m = M().train()
m.qconfig = default_qat_qconfig
m = prepare_qat(m)
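# sigmoid/hardsigmoid have a known output range of [0, 1] and tanh of [-1, 1],
# so their activations get fixed quantization parameters instead of observed ones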
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize)
data = torch.randn(1, 3, 2, 4)
before_convert = m(data)
m = convert(m)
after_convert = m(data)
self.assertEqual(before_convert, after_convert)
# make sure activation post process is removed
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
# verify the fake quant module is removed
self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process'))
# verify that hooks are removed
self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)
# make sure no fake quantize module is inserted for eval mode
def checkNoFQModule(m):
for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)
m = M().eval()
m.qconfig = default_qconfig
m = prepare(m)
checkNoFQModule(m)
m = convert(m)
checkNoFQModule(m)
def test_leaky_relu(self):
data = torch.randn(1, 3, 2, 4)
self._test_activation_convert_numerics_impl(nn.LeakyReLU, data)
def test_relu(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(x)
return x
m = M().train()
m.qconfig = default_qconfig
m = prepare_qat(m)
# make sure no activation_post_process is inserted for relu
self.assertFalse(hasattr(m, "activation_post_process"))
m = convert(m)
# make sure the ReLU module is not changed
self.assertEqual(type(m.relu), nn.ReLU)
@given(batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(['zeros', 'circular']),
use_relu=st.booleans(),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
zero_gamma=st.booleans(),
has_bias=st.booleans(),
use_slow_fusion=st.booleans())
def test_conv_bn_relu(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
use_relu,
eps,
momentum,
freeze_bn,
zero_gamma,
has_bias,
use_slow_fusion,
):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
dilation_h = dilation_w = dilation
conv_op = Conv2d(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
has_bias,
padding_mode
).to(dtype=torch.double)
bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double)
relu_op = ReLU()
cls = ConvBnReLU2d if use_relu else ConvBn2d
qat_op = cls(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
has_bias,
padding_mode,
eps,
momentum,
freeze_bn=True,
qconfig=default_qat_qconfig
).to(dtype=torch.double)
qat_op._enable_slow_path_for_better_numerical_stability = use_slow_fusion
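# the slow path avoids the approximate single-conv folding in exchange for
# better numerical stability (hence the zero-gamma handling below)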
# the approximate fusion will not work if bn.weight contains zeros
if zero_gamma and use_slow_fusion:
torch.nn.init.zeros_(qat_op.bn.weight)
qat_op.apply(torch.ao.quantization.disable_fake_quant)
if freeze_bn:
qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
else:
qat_op.apply(torch.ao.nn.intrinsic.qat.update_bn_stats)
# align inputs and internal parameters
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
if has_bias:
conv_op.bias = torch.nn.Parameter(qat_op.bias.detach())
bn_op.running_mean = qat_op.bn.running_mean.clone()
bn_op.running_var = qat_op.bn.running_var.clone()
bn_op.weight = torch.nn.Parameter(qat_op.bn.weight.detach())
bn_op.bias = torch.nn.Parameter(qat_op.bn.bias.detach())
def compose(functions):
# functions are reversed for natural reading order
return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)
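# e.g. compose([conv_op, bn_op, relu_op])(x) == relu_op(bn_op(conv_op(x)))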
if not use_relu:
def relu_op(x): # noqa: F811
return x
if freeze_bn:
def ref_op(x):
x = conv_op(x)
x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
(bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \
.reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
x = relu_op(x)
return x
else:
ref_op = compose([conv_op, bn_op, relu_op])
input_clone = input.clone().detach().requires_grad_()
for i in range(2):
result_ref = ref_op(input)
result_actual = qat_op(input_clone)
self.assertEqual(result_ref, result_actual)
# backward
dout = torch.randn(result_ref.size(), dtype=torch.double)
loss = (result_ref - dout).sum()
loss.backward()
input_grad_ref = input.grad.cpu()
weight_grad_ref = conv_op.weight.grad.cpu()
gamma_grad_ref = bn_op.weight.grad.cpu()
beta_grad_ref = bn_op.bias.grad.cpu()
running_mean_ref = bn_op.running_mean
running_var_ref = bn_op.running_var
num_batches_tracked_ref = bn_op.num_batches_tracked
loss = (result_actual - dout).sum()
loss.backward()
input_grad_actual = input_clone.grad.cpu()
weight_grad_actual = qat_op.weight.grad.cpu()
gamma_grad_actual = qat_op.bn.weight.grad.cpu()
beta_grad_actual = qat_op.bn.bias.grad.cpu()
running_mean_actual = qat_op.bn.running_mean
running_var_actual = qat_op.bn.running_var
num_batches_tracked_actual = qat_op.bn.num_batches_tracked
precision = 1e-10
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)
@given(batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(['zeros', 'circular']),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
bias=st.booleans())
def test_conv_bn_folded_vs_unfolded(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
eps,
momentum,
freeze_bn,
bias,
):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
dilation_h = dilation_w = dilation
qat_op = ConvBn2d(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
bias, # bias
padding_mode,
eps,
momentum,
freeze_bn=freeze_bn,
qconfig=default_qat_qconfig
).to(dtype=torch.double)
qat_ref_op = _ReferenceConvBn2d(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
bias, # bias
padding_mode,
eps,
momentum,
freeze_bn=freeze_bn,
qconfig=default_qat_qconfig
).to(dtype=torch.double)
qat_op.apply(torch.ao.quantization.disable_fake_quant)
qat_ref_op.apply(torch.ao.quantization.disable_fake_quant)
# align inputs and internal parameters
qat_ref_op.weight = torch.nn.Parameter(qat_op.weight.detach().clone())
qat_ref_op.running_mean = qat_op.bn.running_mean.clone()
qat_ref_op.running_var = qat_op.bn.running_var.clone()
qat_ref_op.gamma = torch.nn.Parameter(qat_op.bn.weight.detach().clone())
qat_ref_op.beta = torch.nn.Parameter(qat_op.bn.bias.detach().clone())
if qat_op.bias is not None:
qat_ref_op.bias = torch.nn.Parameter(qat_op.bias.detach().clone())
lr = 0.01
qat_op_optim = torch.optim.SGD(qat_op.parameters(), lr=lr)
qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=lr)
for i in range(5):
# make sure that calling model.train() does not override the
# bn freeze setting
qat_op.train()
qat_ref_op.train()
qat_op_optim.zero_grad()
qat_ref_op_optim.zero_grad()
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
input_clone = input.clone().detach().requires_grad_()
if i > 2:
qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
qat_ref_op.freeze_bn_stats()
if i > 3:
qat_op.apply(torch.ao.quantization.disable_observer)
qat_ref_op.apply(torch.ao.quantization.disable_observer)
result_ref = qat_ref_op(input)
result_actual = qat_op(input_clone)
self.assertEqual(result_ref, result_actual)
# backward
dout = torch.randn(result_ref.size(), dtype=torch.double) + 10.0
loss = (result_ref - dout).sum()
loss.backward()
input_grad_ref = input.grad.cpu()
weight_grad_ref = qat_ref_op.weight.grad.cpu()
gamma_grad_ref = qat_ref_op.gamma.grad.cpu()
beta_grad_ref = qat_ref_op.beta.grad.cpu()
running_mean_ref = qat_ref_op.running_mean
running_var_ref = qat_ref_op.running_var
num_batches_tracked_ref = qat_ref_op.num_batches_tracked
loss = (result_actual - dout).sum()
loss.backward()
input_grad_actual = input_clone.grad.cpu()
weight_grad_actual = qat_op.weight.grad.cpu()
gamma_grad_actual = qat_op.bn.weight.grad.cpu()
beta_grad_actual = qat_op.bn.bias.grad.cpu()
running_mean_actual = qat_op.bn.running_mean
running_var_actual = qat_op.bn.running_var
num_batches_tracked_actual = qat_op.bn.num_batches_tracked
precision = 1e-5
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)
qat_op_optim.step()
qat_ref_op_optim.step()
@override_qengines
def test_linear_bn_numerics(self):
qengine = torch.backends.quantized.engine
m_ref = nn.Sequential(
nn.Linear(4, 4),
nn.BatchNorm1d(4),
)
m_ref_copy = copy.deepcopy(m_ref)
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
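# fuse_modules_qat replaces the (Linear, BatchNorm1d) pair with an intrinsic
# LinearBn1d at index 0 (and an Identity at index 1), which from_float below
# turns into the fused QAT module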
qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
m_ref_copy[0].qconfig = qconfig
m = nniqat.LinearBn1d.from_float(m_ref_copy[0])
# without fake_quants, fused QAT module should match fp32 module
m.apply(torch.ao.quantization.disable_fake_quant)
data = torch.randn(4, 4)
r1 = m_ref(data)
r2 = m(data)
self.assertTrue(torch.allclose(r1, r2))
@skipIfNoXNNPACK
@override_qengines
def test_linear_bn_symm_numerics(self):
qengine = torch.backends.quantized.engine
if qengine != "qnnpack":
return  # Only qnnpack supports symmetric quantization
m_ref = nn.Sequential(
nn.Linear(4, 4),
nn.BatchNorm1d(4),
)
m_ref_copy = copy.deepcopy(m_ref)
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
qconfig = default_symmetric_qnnpack_qat_qconfig
m_ref_copy[0].qconfig = qconfig
m = nniqat.LinearBn1d.from_float(m_ref_copy[0])
# without fake_quants, fused QAT module should match fp32 module
m.apply(torch.ao.quantization.disable_fake_quant)
data = torch.randn(4, 4)
r1 = m_ref(data)
r2 = m(data)
self.assertTrue(torch.allclose(r1, r2))
@override_qengines
def test_linear_bn_workflow(self):
qengine = torch.backends.quantized.engine
m = nn.Sequential(
QuantStub(),
nn.Linear(4, 4),
nn.BatchNorm1d(4),
)
data = torch.randn(4, 4)
m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']])
mp = prepare_qat(m)
mp(data)
mq = convert(mp)
self.assertTrue(type(mq[1]) == nnq.Linear)
self.assertTrue(type(mq[2]) == nn.Identity)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"
"instead.")