pytorch/test/quantization/test_qat_module.py
Mike Ruberry 13120bf677 Updates assertEqual to require atol and rtol, removes positional atol (#38872)
Summary:
This updates assertEqual and assertEqual-like functions to either require both or neither of atol and rtol be specified. This should improve clarity around handling precision in the test suite, and it allows us to remove the legacy positional atol argument from assertEqual. In addition, the "message" kwarg is replace with a kwarg-only "msg" argument whose name is consistent with unittest's assertEqual argument.

In the future we could make "msg" an optional third positional argument to be more consistent with unittest's assertEqual, but requiring it be specified should be clear, and we can easily update the signature to make "msg" an optional positional argument in the future, too.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38872

Differential Revision: D21740237

Pulled By: mruberry

fbshipit-source-id: acbc027aa1d7877a49664d94db9a5fff91a07042
2020-05-27 06:31:07 -07:00

510 lines
22 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import math
import torch
import torch.nn as nn
from torch.nn import Conv2d, BatchNorm2d, ReLU, init
from torch.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.nn.modules.utils import _pair
from torch.quantization.qconfig import default_qat_qconfig
import torch.backends.mkldnn
from torch.testing._internal.common_utils import TestCase
from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
from functools import reduce
class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
"""
Conv-BN fusion implemented with explicit folding. Useful
to verify numerical equivalency with non-folded version.
"""
def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding,
groups,
bias,
padding_mode,
# BatchNormNd args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
stride, padding, dilation, transposed,
output_padding, groups, False, padding_mode)
assert qconfig, 'qconfig must be provided for QAT module'
self.qconfig = qconfig
self.eps = eps
self.momentum = momentum
self.freeze_bn = freeze_bn if self.training else True
self.num_features = out_channels
self.gamma = nn.Parameter(torch.Tensor(out_channels))
self.beta = nn.Parameter(torch.Tensor(out_channels))
self.affine = True
self.track_running_stats = True
self.register_buffer('running_mean', torch.zeros(out_channels))
self.register_buffer('running_var', torch.ones(out_channels))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
self.activation_post_process = self.qconfig.activation()
self.weight_fake_quant = self.qconfig.weight()
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_bn_parameters()
def reset_running_stats(self):
self.running_mean.zero_()
self.running_var.fill_(1)
self.num_batches_tracked.zero_()
def reset_bn_parameters(self):
self.reset_running_stats()
init.uniform_(self.gamma)
init.zeros_(self.beta)
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def reset_parameters(self):
super(_ReferenceConvBnNd, self).reset_parameters()
# A hack to avoid resetting on undefined parameters
if hasattr(self, 'gamma'):
self.reset_bn_parameters()
def update_bn_stats(self):
self.freeze_bn = False
return self
def freeze_bn_stats(self):
self.freeze_bn = True
return self
def _forward(self, input):
# exponential_average_factor is self.momentum set to
# (when it is available) only so that if gets updated
# in ONNX graph when this node is exported to ONNX.
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and not self.freeze_bn and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
# we use running statistics from the previous batch, so this is an
# approximation of the approach mentioned in the whitepaper, but we only
# need to do one convolution in this case instead of two
running_std = torch.sqrt(self.running_var + self.eps)
scale_factor = self.gamma / running_std
scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight))
if self.training and not self.freeze_bn:
# recovering original conv to get original batch_mean and batch_var
if self.bias is not None:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1])
else:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
n = float(conv_orig.numel() / conv_orig.size()[1])
unbiased_batch_var = batch_var * (n / (n - 1))
batch_rstd = torch.ones_like(batch_var, memory_format=torch.contiguous_format) / torch.sqrt(batch_var + self.eps)
conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + \
(self.beta - self.gamma * batch_rstd * batch_mean).reshape([1, -1, 1, 1])
self.running_mean = exponential_average_factor * batch_mean.detach() + \
(1 - exponential_average_factor) * self.running_mean
self.running_var = exponential_average_factor * unbiased_batch_var.detach() + \
(1 - exponential_average_factor) * self.running_var
else:
if self.bias is None:
conv = conv + (self.beta - self.gamma * self.running_mean /
running_std).reshape([1, -1, 1, 1])
else:
conv = conv + (self.gamma * (self.bias - self.running_mean) / running_std + self.beta).reshape([1, -1, 1, 1])
return conv
def extra_repr(self):
# TODO(jerryzh): extend
return super(_ReferenceConvBnNd, self).extra_repr()
def forward(self, input):
return self.activation_post_process(self._forward(input))
@classmethod
def from_float(cls, mod, qconfig=None):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
if not qconfig:
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
qconfig = mod.qconfig
conv, bn = mod[0], mod[1]
qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
conv.stride, conv.padding, conv.dilation,
conv.groups, conv.bias is not None,
conv.padding_mode,
bn.eps, bn.momentum,
False,
qconfig)
qat_convbn.weight = conv.weight
qat_convbn.bias = conv.bias
qat_convbn.gamma = bn.weight
qat_convbn.beta = bn.bias
qat_convbn.running_mean = bn.running_mean
qat_convbn.running_var = bn.running_var
qat_convbn.num_batches_tracked = bn.num_batches_tracked
return qat_convbn
class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d):
_FLOAT_MODULE = torch.nn.intrinsic.ConvBn2d
def __init__(self,
# ConvNd args
in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=None,
padding_mode='zeros',
# BatchNorm2d args
# num_features: out_channels
eps=1e-05, momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
_ReferenceConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, False, _pair(0), groups, bias, padding_mode,
eps, momentum, freeze_bn, qconfig)
class TestQATModule(TestCase):
@given(batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(['zeros', 'circular']),
use_relu=st.booleans(),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans())
def test_conv_bn_relu(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
use_relu,
eps,
momentum,
freeze_bn
):
# **** WARNING: This is used to temporarily disable MKL-DNN convolution due
# to a bug: https://github.com/pytorch/pytorch/issues/23825
# Once this bug is fixed, this context manager as well as its callsites
# should be removed!
with torch.backends.mkldnn.flags(enabled=False):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
dilation_h = dilation_w = dilation
conv_op = Conv2d(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
False, # No bias
padding_mode
).to(dtype=torch.double)
bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double)
relu_op = ReLU()
cls = ConvBnReLU2d if use_relu else ConvBn2d
qat_op = cls(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
None, # bias
padding_mode,
eps,
momentum,
freeze_bn=True,
qconfig=default_qat_qconfig
).to(dtype=torch.double)
qat_op.apply(torch.quantization.disable_fake_quant)
if freeze_bn:
qat_op.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
else:
qat_op.apply(torch.nn.intrinsic.qat.update_bn_stats)
# align inputs and internal parameters
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
bn_op.running_mean = qat_op.bn.running_mean.clone()
bn_op.running_var = qat_op.bn.running_var.clone()
bn_op.weight = torch.nn.Parameter(qat_op.bn.weight.detach())
bn_op.bias = torch.nn.Parameter(qat_op.bn.bias.detach())
def compose(functions):
# functions are reversed for natural reading order
return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)
if not use_relu:
def relu_op(x):
return x
if freeze_bn:
def ref_op(x):
x = conv_op(x)
x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
(bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \
.reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
x = relu_op(x)
return x
else:
ref_op = compose([conv_op, bn_op, relu_op])
input_clone = input.clone().detach().requires_grad_()
for i in range(2):
result_ref = ref_op(input)
result_actual = qat_op(input_clone)
self.assertEqual(result_ref, result_actual)
# backward
dout = torch.randn(result_ref.size(), dtype=torch.double)
loss = (result_ref - dout).sum()
loss.backward()
input_grad_ref = input.grad.cpu()
weight_grad_ref = conv_op.weight.grad.cpu()
gamma_grad_ref = bn_op.weight.grad.cpu()
beta_grad_ref = bn_op.bias.grad.cpu()
running_mean_ref = bn_op.running_mean
running_var_ref = bn_op.running_var
num_batches_tracked_ref = bn_op.num_batches_tracked
loss = (result_actual - dout).sum()
loss.backward()
input_grad_actual = input_clone.grad.cpu()
weight_grad_actual = qat_op.weight.grad.cpu()
gamma_grad_actual = qat_op.bn.weight.grad.cpu()
beta_grad_actual = qat_op.bn.bias.grad.cpu()
running_mean_actual = qat_op.bn.running_mean
running_var_actual = qat_op.bn.running_var
num_batches_tracked_actual = qat_op.bn.num_batches_tracked
precision = 1e-10
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)
@given(batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(['zeros', 'circular']),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
bias=st.booleans())
def test_conv_bn_folded_vs_unfolded(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
eps,
momentum,
freeze_bn,
bias,
):
# **** WARNING: This is used to temporarily disable MKL-DNN convolution due
# to a bug: https://github.com/pytorch/pytorch/issues/23825
# Once this bug is fixed, this context manager as well as its callsites
# should be removed!
with torch.backends.mkldnn.flags(enabled=False):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
dilation_h = dilation_w = dilation
qat_op = ConvBn2d(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
bias, # bias
padding_mode,
eps,
momentum,
freeze_bn=freeze_bn,
qconfig=default_qat_qconfig
).to(dtype=torch.double)
qat_ref_op = _ReferenceConvBn2d(
input_channels,
output_channels,
(kernel_h, kernel_w),
(stride_h, stride_w),
(pad_h, pad_w),
(dilation_h, dilation_w),
groups,
bias, # bias
padding_mode,
eps,
momentum,
freeze_bn=freeze_bn,
qconfig=default_qat_qconfig
).to(dtype=torch.double)
qat_op.apply(torch.quantization.disable_fake_quant)
qat_ref_op.apply(torch.quantization.disable_fake_quant)
# align inputs and internal parameters
qat_ref_op.weight = torch.nn.Parameter(qat_op.weight.detach().clone())
qat_ref_op.running_mean = qat_op.bn.running_mean.clone()
qat_ref_op.running_var = qat_op.bn.running_var.clone()
qat_ref_op.gamma = torch.nn.Parameter(qat_op.bn.weight.detach().clone())
qat_ref_op.beta = torch.nn.Parameter(qat_op.bn.bias.detach().clone())
if qat_op.bias is not None:
qat_ref_op.bias = torch.nn.Parameter(qat_op.bias.detach().clone())
lr = 0.01
qat_op_optim = torch.optim.SGD(qat_op.parameters(), lr=lr)
qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=lr)
for i in range(5):
# make sure that calling model.train() does not override the
# bn freeze setting
qat_op.train()
qat_ref_op.train()
qat_op_optim.zero_grad()
qat_ref_op_optim.zero_grad()
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
input_clone = input.clone().detach().requires_grad_()
if i > 2:
qat_op.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
qat_ref_op.freeze_bn_stats()
if i > 3:
qat_op.apply(torch.quantization.disable_observer)
qat_ref_op.apply(torch.quantization.disable_observer)
result_ref = qat_ref_op(input)
result_actual = qat_op(input_clone)
self.assertEqual(result_ref, result_actual)
# backward
dout = torch.randn(result_ref.size(), dtype=torch.double) + 10.0
loss = (result_ref - dout).sum()
loss.backward()
input_grad_ref = input.grad.cpu()
weight_grad_ref = qat_ref_op.weight.grad.cpu()
gamma_grad_ref = qat_ref_op.gamma.grad.cpu()
beta_grad_ref = qat_ref_op.beta.grad.cpu()
running_mean_ref = qat_ref_op.running_mean
running_var_ref = qat_ref_op.running_var
num_batches_tracked_ref = qat_ref_op.num_batches_tracked
loss = (result_actual - dout).sum()
loss.backward()
input_grad_actual = input_clone.grad.cpu()
weight_grad_actual = qat_op.weight.grad.cpu()
gamma_grad_actual = qat_op.bn.weight.grad.cpu()
beta_grad_actual = qat_op.bn.bias.grad.cpu()
running_mean_actual = qat_op.bn.running_mean
running_var_actual = qat_op.bn.running_var
num_batches_tracked_actual = qat_op.bn.num_batches_tracked
precision = 1e-5
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)
qat_op_optim.step()
qat_ref_op_optim.step()