pytorch/test/test_quantization.py
Raghuraman Krishnamoorthi 17c1b2c715 Relax scale to prevent saturation in conv/linear. Add test to verify precision of numerics of quantized model with updated observer. This test catches errors in (#25667)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25667

Relax scale and zero-point for activations to ensure that fbgemm implementations of conv and linear do not saturate due to 16 bit intermediate accumulation.

Add test to verify precision of numerics of quantized model with updated observer. This test catches errors in
handling layouts for quantized ops in addition to saturation/quantization errors.
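For intuition, reduce_range narrows the activation quantization range from 255 steps to 127 (roughly one fewer bit), which enlarges the activation scale by 255/127 and leaves headroom in fbgemm's 16 bit intermediate accumulator. A minimal sketch of the idea (illustration only; affine_qparams_quint8 is a hypothetical helper, not the MinMaxObserver implementation), consistent with the reference values in ObserverTest.test_minmax_observer below:

    def affine_qparams_quint8(min_val, max_val, reduce_range=False):
        # Hypothetical helper: per-tensor affine qparams for quint8 activations.
        # The observed range is extended to include zero, mirroring MinMaxObserver.
        min_val, max_val = min(0.0, min_val), max(0.0, max_val)
        # reduce_range keeps only 7 of the 8 bits so that the 16 bit intermediate
        # accumulation in fbgemm conv/linear cannot saturate.
        qmin, qmax = (0, 127) if reduce_range else (0, 255)
        scale = (max_val - min_val) / float(qmax - qmin)
        # guard against a degenerate all-zero range
        scale = scale if scale > 0 else 1.0
        zero_point = qmin - int(round(min_val / scale))
        return scale, zero_point

    # With an observed range of [1.0, 8.0] (as in test_minmax_observer):
    #   reduce_range=False -> scale ~= 8 / 255, zero_point = 0
    #   reduce_range=True  -> scale ~= 8 / 127 = (8 / 255) * 255 / 127, zero_point = 0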
ghstack-source-id: 89587942

Test Plan:
buck test caffe2/test:quantized -- 'test_float_quant_compare \(test_quantized_models\.ModelNumerics\)' --print-passing-details

Passes when SQNR > 35 dB
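(The SQNR check is presumably along these lines; compute_sqnr is a hypothetical name used for illustration only, not the actual helper in the model-numerics test.)

    import torch

    def compute_sqnr(float_out, quant_out):
        # Signal-to-quantization-noise ratio, in dB, between the float model
        # output and the dequantized output of the quantized model.
        noise = float_out - quant_out
        return 20 * torch.log10(torch.norm(float_out) / torch.norm(noise))

    # test_float_quant_compare passes when this ratio exceeds 35 dB.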

buck test caffe2/test:quantization -- 'test_minmax_observer \(test_quantization\.ObserverTest\)' --print-passing-details
Passes with additional coverage for observer changes

Differential Revision: D17140498

fbshipit-source-id: 42c58e726bb0b0f51890590ee2525428f9a8d24e
2019-09-06 17:18:01 -07:00

import unittest
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn._intrinsic as nni
import torch.nn._intrinsic.quantized as nniq
import torch.nn._intrinsic.qat as nniqat
from torch.quantization import \
    QConfig_dynamic, default_weight_observer, \
    quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \
    quantize_dynamic, default_qconfig, default_qat_qconfig, \
    default_dynamic_qconfig, MinMaxObserver, QuantWrapper

from common_utils import run_tests, tempfile
from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \
    SkipQuantModel, QuantStubModel, \
    ModelForFusion, ManualLinearQATModel, ManualConvLinearQATModel, \
    ModForWrapping, \
    test_only_eval_fn, test_only_train_fn, \
    prepare_dynamic, convert_dynamic, SingleLayerLinearDynamicModel, \
    TwoLayerLinearModel, NestedModel, ResNetBase, LSTMDynamicModel
from common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \
    AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel

from hypothesis import given
from hypothesis import strategies as st
import io
import copy


@unittest.skipIf(
    not torch.fbgemm_is_cpu_supported(),
    " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
    " with instruction set support avx2 or newer.",
)
class PostTrainingQuantTest(QuantizationTestCase):
    def test_single_layer(self):
        r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
        to nnq.Linear which is the quantized version of the module
        """
        model = SingleLayerLinearModel()
        prepare(model)
        # Check if observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkHasPrepModules(model.fc1)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.checkHasPrepModules(model.fc1)
            self.checkWrappedQuantizedLinear(model.fc1)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(SingleLayerLinearModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_two_layers(self):
        r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one
        `fc2`, and `fc1` is not quantized
        """
        model = AnnotatedTwoLayerLinearModel()
        prepare(model)

        self.checkNoPrepModules(model)
        self.checkObservers(model)
        self.checkNoPrepModules(model.fc1)
        self.checkHasPrepModules(model.fc2)

        test_only_eval_fn(model, self.calib_data)
        convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.fc1)
            self.checkHasPrepModules(model.fc2)
            self.assertEqual(type(model.fc1), torch.nn.Linear)
            self.checkWrappedQuantizedLinear(model.fc2)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedTwoLayerLinearModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_nested1(self):
        r"""Test quantization for a nested model: top level 'fc3' and
        'fc1' of submodule 'sub2' are quantized; 'sub2.fc2' is not quantized
        """
        model = AnnotatedNestedModel()

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkNoPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        prepare(model)
        checkPrepModules(model, True)
        test_only_eval_fn(model, self.calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.checkWrappedQuantizedLinear(model.fc3)
            self.checkWrappedQuantizedLinear(model.sub2.fc1)
            self.checkLinear(model.sub2.fc2)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedNestedModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_nested2(self):
        model = AnnotatedSubNestedModel()
        prepare(model)

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkHasPrepModules(model.sub2)
            self.checkNoPrepModules(model.sub2.module.fc1)
            self.checkNoPrepModules(model.sub2.module.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        test_only_eval_fn(model, self.calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
            self.checkQuantizedLinear(model.sub2.module.fc1)
            self.checkQuantizedLinear(model.sub2.module.fc2)
            self.checkWrappedQuantizedLinear(model.fc3)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedSubNestedModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_nested3(self):
        r"""More complicated nested test case where a child qconfig overrides
        the parent qconfig
        """
        model = AnnotatedCustomConfigNestedModel()
        prepare(model)

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkNoPrepModules(model.sub2)
            self.checkHasPrepModules(model.sub2.fc1)
            self.checkHasPrepModules(model.sub2.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        test_only_eval_fn(model, self.calib_data)
        convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkWrappedQuantizedLinear(model.sub2.fc1)
            self.checkWrappedQuantizedLinear(model.sub2.fc2)
            self.checkWrappedQuantizedLinear(model.fc3)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedCustomConfigNestedModel(), test_only_eval_fn,
                         self.calib_data)
        checkQuantized(model)

    def test_skip_quant(self):
        r"""The case when we want to skip quantizing some layers
        """
        model = SkipQuantModel()
        prepare(model)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        convert(model)

        def checkQuantized(model):
            self.checkLinear(model.fc)
            self.checkQuantDequant(model.sub)
            self.checkQuantizedLinear(model.sub.module.fc1)
            self.checkQuantizedLinear(model.sub.module.fc2)
            self.assertEqual(type(model.sub.module.relu), nnq.ReLU)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(SkipQuantModel(), test_only_eval_fn, self.calib_data)
        checkQuantized(model)

    def test_manual(self):
        r"""User inserts QuantStub and DeQuantStub in model code
        and calls the quantization utility functions.
        """
        model = QuantStubModel()
        # propagate the qconfig of parents to children, model is changed
        # inplace
        prepare(model)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.fc), nnq.Linear)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize(QuantStubModel(), test_only_eval_fn, self.calib_data)
        checkQuantized(model)

    def test_resnet_base(self):
        r"""Test quantization for bottleneck topology used in resnet/resnext
        and add coverage for conversion of average pool and float functional
        """
        model = ResNetBase().float().eval()
        model = QuantWrapper(model)
        model.qconfig = default_qconfig
        fuse_list = [['module.conv1', 'module.bn1', 'module.relu1']]
        fuse_modules(model, fuse_list)
        prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data)
        convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.module.conv1), nn._intrinsic.quantized.ConvReLU2d)
            self.assertEqual(type(model.module.myop), nn.quantized.QFunctional)
            self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d)
            test_only_eval_fn(model, self.img_data)

        checkQuantized(model)


@unittest.skipIf(
    not torch.fbgemm_is_cpu_supported(),
    " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
    " with instruction set support avx2 or newer.",
)
class PostTrainingDynamicQuantTest(QuantizationTestCase):
    def test_single_layer(self):
        r"""Dynamic Quantize SingleLayerLinearDynamicModel which has one Linear module,
        make sure it is swapped to nnqd.Linear which is the quantized version of
        the module
        """
        model = SingleLayerLinearDynamicModel().eval()
        qconfig_dict = {
            '': default_dynamic_qconfig
        }
        prepare_dynamic(model, qconfig_dict)
        convert_dynamic(model)

        def checkQuantized(model):
            self.checkDynamicQuantizedLinear(model.fc1)

        checkQuantized(model)

        # test one line API
        model = quantize_dynamic(SingleLayerLinearDynamicModel().eval(),
                                 qconfig_dict)
        checkQuantized(model)

    def test_two_layers(self):
        r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one
        `fc2`, and `fc1` is not quantized
        """
        model = TwoLayerLinearModel().eval()
        qconfig_dict = {
            'fc2': default_dynamic_qconfig
        }
        prepare_dynamic(model, qconfig_dict)
        convert_dynamic(model)

        def checkQuantized(model):
            self.assertEqual(type(model.fc1), torch.nn.Linear)
            self.checkDynamicQuantizedLinear(model.fc2)

        checkQuantized(model)

        # test one line API
        model = quantize_dynamic(TwoLayerLinearModel().eval(), qconfig_dict)
        checkQuantized(model)

    def test_nested1(self):
        r"""Test quantization for a nested model: top level 'fc3' and
        'fc1' of submodule 'sub2' are quantized; 'sub2.fc2' is not quantized
        """
        model = NestedModel().eval()
        qconfig_dict = {
            'fc3': default_dynamic_qconfig,
            'sub2.fc1': default_dynamic_qconfig
        }
        prepare_dynamic(model, qconfig_dict)
        convert_dynamic(model)

        def checkQuantized(model):
            self.checkLinear(model.sub1.fc)
            self.checkDynamicQuantizedLinear(model.fc3)
            self.checkDynamicQuantizedLinear(model.sub2.fc1)
            self.checkLinear(model.sub2.fc2)

        checkQuantized(model)

        # test one line API
        model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
        checkQuantized(model)

    def test_nested2(self):
        r"""Another test case for dynamic quantization: here we quantize all
        submodules of submodule sub2
        """
        model = NestedModel().eval()
        qconfig_dict = {
            'fc3': default_dynamic_qconfig,
            'sub2': default_dynamic_qconfig
        }
        prepare_dynamic(model, qconfig_dict)
        convert_dynamic(model)

        def checkQuantized(model):
            self.checkLinear(model.sub1.fc)
            self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
            self.checkDynamicQuantizedLinear(model.sub2.fc1)
            self.checkDynamicQuantizedLinear(model.sub2.fc2)
            self.checkDynamicQuantizedLinear(model.fc3)

        checkQuantized(model)

        # test one line API
        model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
        checkQuantized(model)

    def test_nested3(self):
        r"""More complicated nested test case where a child qconfig overrides
        the parent qconfig
        """
        model = NestedModel().eval()
        custom_options = {
            'dtype': torch.quint8,
            'qscheme': torch.per_tensor_affine
        }
        custom_dynamic_qconfig = QConfig_dynamic(weight=default_weight_observer())
        qconfig_dynamic_dict = {
            'fc3': default_dynamic_qconfig,
            'sub2': default_dynamic_qconfig,
            'sub2.fc1': custom_dynamic_qconfig
        }
        prepare_dynamic(model, qconfig_dynamic_dict)
        convert_dynamic(model)

        def checkQuantized(model):
            self.checkDynamicQuantizedLinear(model.sub2.fc1)
            self.checkDynamicQuantizedLinear(model.sub2.fc2)
            self.checkDynamicQuantizedLinear(model.fc3)

        checkQuantized(model)

        # test one line API
        model = quantize_dynamic(NestedModel().eval(), qconfig_dynamic_dict)
        checkQuantized(model)

    def test_type_match_rule(self):
        r"""Test quantization with the type match rule: all 'torch.nn.Linear'
        modules are quantized except top level 'fc3' and 'fc1' of submodule
        'sub2', which are excluded by name
        """
        model = NestedModel().eval()
        qconfig_dict = {
            'fc3': None,
            'sub2.fc1': None,
            torch.nn.Linear: default_dynamic_qconfig
        }
        prepare_dynamic(model, qconfig_dict)
        test_only_eval_fn(model, self.calib_data)
        convert_dynamic(model)

        def checkQuantized(model):
            self.checkDynamicQuantizedLinear(model.sub1.fc)
            self.checkLinear(model.fc3)
            self.checkLinear(model.sub2.fc1)
            self.checkDynamicQuantizedLinear(model.sub2.fc2)
            test_only_eval_fn(model, self.calib_data)

        checkQuantized(model)

        # test one line API
        model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
        checkQuantized(model)

    def test_quantized_rnn(self):
        d_in, d_hid = 2, 2
        model = LSTMDynamicModel().eval()
        cell = model.lstm

        # Replace parameter values s.t. the range of values is exactly
        # 255, thus we will have 0 quantization error in the quantized
        # GEMM call. This is for testing purposes.
        #
        # Note that the current implementation does not support
        # accumulation values outside of the range representable by a
        # 16 bit integer, instead resulting in a saturated value. We
        # must take care that in our test we do not end up with a dot
        # product that overflows the int16 range, e.g.
        # (255*127+255*127) = 64770. So, we hardcode the test values
        # here and ensure a mix of signedness.
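        # For reference: int16 saturates at 32767, so an accumulation like
        # 64770 in the example above would be clamped rather than wrapped.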
        vals = [[100, -155],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155]]
        if isinstance(cell, torch.nn.LSTM):
            num_chunks = 4
        vals = vals[:d_hid * num_chunks]
        cell.weight_ih_l0 = torch.nn.Parameter(
            torch.tensor(vals, dtype=torch.float),
            requires_grad=False)
        cell.weight_hh_l0 = torch.nn.Parameter(
            torch.tensor(vals, dtype=torch.float),
            requires_grad=False)

        ref = copy.deepcopy(cell)

        qconfig_dynamic_dict = {
            torch.nn.LSTM: default_dynamic_qconfig,
        }
        default_dynamic_module_mapping = {
            torch.nn.LSTM: torch.nn.quantized.dynamic.LSTM,
        }
        model_int8 = quantize_dynamic(
            model, qconfig_dynamic_dict, default_dynamic_module_mapping
        )
        cell_int8 = model_int8.lstm

        assert type(cell_int8) == torch.nn.quantized.dynamic.LSTM, \
            'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic'

        niter = 10
        x = torch.tensor([[100, -155],
                          [-155, 100],
                          [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)

        h0_vals = [[-155, 100],
                   [-155, 155],
                   [100, -155]]
        hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)
        cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)

        if isinstance(ref, torch.nn.LSTM):
            hiddens = (hx, cx)
        ref_out, ref_hid = ref(x, hiddens)

        # Compare int8 quantized to unquantized
        output_int8, final_hiddens_int8 = cell_int8(x, hiddens)

        torch.testing.assert_allclose(output_int8, ref_out)
        self.assertEqual(output_int8, ref_out)
        for out, ref in zip(final_hiddens_int8, ref_hid):
            torch.testing.assert_allclose(out, ref)


@unittest.skipIf(
    not torch.fbgemm_is_cpu_supported(),
    " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
    " with instruction set support avx2 or newer.",
)
class QuantizationAwareTrainingTest(QuantizationTestCase):
    def test_manual(self):
        model = ManualLinearQATModel()
        prepare_qat(model)
        self.checkObservers(model)
        test_only_train_fn(model, self.train_data)
        convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.fc1), nnq.Linear)
            self.assertEqual(type(model.fc2), nnq.Linear)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        model = quantize_qat(ManualLinearQATModel(), test_only_train_fn,
                             self.train_data)
        checkQuantized(model)

    def test_eval_only_fake_quant(self):
        r"""Using FakeQuant in evaluation-only mode is useful for estimating
        the accuracy loss when we quantize the network
        """
        model = ManualLinearQATModel()

        prepare_qat(model)
        self.checkObservers(model)

        model.eval()
        test_only_eval_fn(model, self.calib_data)

    def test_conv_linear(self):
        model = ManualConvLinearQATModel()

        prepare_qat(model)
        self.checkObservers(model)

        test_only_train_fn(model, self.img_data)
        convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv), nnq.Conv2d)
            self.assertEqual(type(model.fc1), nnq.Linear)
            self.assertEqual(type(model.fc2), nnq.Linear)
            test_only_eval_fn(model, self.img_data)
            self.checkScriptable(model, self.img_data)

        checkQuantized(model)

        model = ManualConvLinearQATModel()
        model = quantize_qat(model, test_only_train_fn, self.img_data)
        checkQuantized(model)


class ScriptabilityTest(QuantizationTestCase):
    def setUp(self):
        self.model_under_test = ModForWrapping(quantized=False)
        self.qmodel_under_test = ModForWrapping(quantized=True)
        self.qmodel_under_test = self.qmodel_under_test.from_float(
            self.model_under_test)
        self.x = torch.rand(10)
        self.qx = torch.quantize_linear(self.x.to(torch.float), scale=1.0,
                                        zero_point=0, dtype=torch.qint32)

    def test_scriptability_serialization(self):
        # test serialization of quantized functional modules
        with tempfile.TemporaryFile() as f:
            torch.save(self.qmodel_under_test, f)
            f.seek(0)
            loaded = torch.load(f)
            self.assertEqual(self.qmodel_under_test.myadd.zero_point, loaded.myadd.zero_point)

        state_dict = self.qmodel_under_test.state_dict()
        self.assertTrue('myadd.zero_point' in state_dict.keys(),
                        'zero point not in state dict for functional modules')

        x = torch.rand(10, 1, dtype=torch.float)
        xq = torch.quantize_linear(x, 1.0, 0, torch.qint8)
        self.checkScriptable(self.qmodel_under_test, [(xq, xq)], check_save_load=True)
        self.checkScriptable(self.model_under_test, [(xq.dequantize(), xq.dequantize())], check_save_load=True)


@unittest.skipIf(not torch.fbgemm_is_cpu_supported(),
                 'Quantization requires FBGEMM. FBGEMM does not play'
                 ' well with UBSAN at the moment, so we skip the test if'
                 ' we are in a UBSAN environment.')
class FusionTest(QuantizationTestCase):
    def test_fuse_module_train(self):
        model = ModelForFusion(default_qat_qconfig).train()
        fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                             ['sub1.conv', 'sub1.bn']])
        self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                         "Fused Conv + BN + Relu first layer")
        self.assertEqual(type(model.bn1), torch.nn.Identity,
                         "Fused Conv + BN + Relu (skipped BN)")
        self.assertEqual(type(model.relu1), torch.nn.Identity,
                         "Fused Conv + BN + Relu (skipped Relu)")

        self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
                         "Fused submodule Conv + BN")
        self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
                         "Fused submodule Conv + BN (skipped BN)")
        self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
                         "Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         "Non-fused submodule ReLU")

        prepare_qat(model)
        self.checkObservers(model)

        def checkQAT(model):
            self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)

        checkQAT(model)
        test_only_train_fn(model, self.img_data)
        convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            test_only_eval_fn(model, self.img_data)

        checkQuantized(model)

        model = ModelForFusion(default_qat_qconfig).train()
        fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                             ['sub1.conv', 'sub1.bn']])
        model = quantize_qat(model, test_only_train_fn, self.img_data)
        checkQuantized(model)

    def test_fuse_module_eval(self):
        model = ModelForFusion(default_qconfig)
        model.eval()
        fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                             ['sub1.conv', 'sub1.bn']])
        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                         "Fused Conv + BN + Relu first layer (BN is folded)")
        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                         "Fused Conv + BN + Relu (Conv + folded BN only)")
        self.assertEqual(type(model.conv1[1]), nn.ReLU,
                         "Fused Conv + BN + Relu second layer (Relu only)")
        self.assertEqual(type(model.bn1), nn.Identity,
                         "Fused Conv + BN + Relu second layer (Skipped BN)")
        self.assertEqual(type(model.relu1), nn.Identity,
                         "Fused Conv + BN + Relu second layer (Skipped Relu)")

        self.assertEqual(type(model.sub1.conv), nn.Conv2d,
                         "Fused submodule Conv + folded BN")
        self.assertEqual(type(model.sub1.bn), nn.Identity,
                         "Fused submodule (skipped BN)")
        self.assertEqual(type(model.sub2.conv), nn.Conv2d,
                         "Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         "Non-fused submodule ReLU")

        prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data)
        convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            test_only_eval_fn(model, self.img_data)

        checkQuantized(model)

        model = ModelForFusion(default_qat_qconfig).eval()
        fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                             ['sub1.conv', 'sub1.bn']])
        model = quantize(model, test_only_eval_fn, self.img_data)
        checkQuantized(model)


class ObserverTest(QuantizationTestCase):
    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    def test_minmax_observer(self, qdtype, qscheme, reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8
        if qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric:
            reduce_range = False
        myobs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0])
        result = myobs(x)
        result = myobs(y)
        self.assertEqual(result, y)
        self.assertEqual(myobs.min_val, 1.0)
        self.assertEqual(myobs.max_val, 8.0)
        qparams = myobs.calculate_qparams()
        if reduce_range:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.062745 * 255 / 127
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0313725 * 255 / 127
                ref_zero_point = -64 if qdtype is torch.qint8 else 0
        else:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.062745
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0313725
                ref_zero_point = -128 if qdtype is torch.qint8 else 0
        self.assertEqual(qparams[1].item(), ref_zero_point)
        self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)

    def test_observer_scriptable(self):
        obs = torch.quantization.default_observer()()
        scripted = torch.jit.script(obs)

        x = torch.rand(3, 4)
        obs(x)
        scripted(x)
        self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams())

        buf = io.BytesIO()
        torch.jit.save(scripted, buf)
        buf.seek(0)
        loaded = torch.jit.load(buf)
        self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams())


if __name__ == '__main__':
    run_tests()