diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh
index 6741d7d6bf4..0b41c519587 100755
--- a/.jenkins/pytorch/test.sh
+++ b/.jenkins/pytorch/test.sh
@@ -110,7 +110,7 @@ test_python_nn() {
 }
 
 test_python_all_except_nn() {
-  time python test/run_test.py --exclude nn --verbose --bring-to-front quantization quantized quantized_tensor quantized_nn_mods quantizer
+  time python test/run_test.py --exclude nn --verbose --bring-to-front quantization quantized quantized_tensor quantized_nn_mods
   assert_git_not_dirty
 }
 
diff --git a/test/run_test.py b/test/run_test.py
index 631fa28a33c..52bfff8df05 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -47,7 +47,6 @@ TESTS = [
     'quantized',
     'quantized_tensor',
     'quantized_nn_mods',
-    'quantizer',
     'sparse',
     'torch',
     'type_info',
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 8b172d6b29d..64f4303b7ef 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -31,6 +31,7 @@ from common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedMod
     AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel
 
 from jit_utils import _tmp_donotuse_dont_inline_everything
+from jit_utils import get_forward
 
 from hypothesis import given
 from hypothesis import strategies as st
@@ -695,10 +696,49 @@ class GraphModePostTrainingQuantTest(QuantizationTestCase):
             [self.calib_data],
             inplace=False)
         result_eager = model_eager(self.calib_data[0][0])
-        torch._C._jit_pass_quant_fusion(model_script._c._get_module('fc1')._get_method('forward').graph)
-        result_script = model_script._c._get_method('forward')(self.calib_data[0][0])
+        result_script = get_forward(model_script._c)(self.calib_data[0][0])
         self.assertEqual(result_eager, result_script)
 
+    @unittest.skip("quantization for inlined linear is not working right now")
+    def test_nested(self):
+        # Eager mode
+        eager_model = AnnotatedNestedModel()
+        # default_per_channel_qconfig is not scriptable right now,
+        # temporarily change to default_qconfig until default_per_channel_qconfig is fixed
+        eager_model.sub2.fc1.qconfig = default_qconfig
+
+        # Graph mode
+        script_model = NestedModel()
+        # Copy weights for eager_model
+        script_model.sub1.fc.weight = torch.nn.Parameter(eager_model.sub1.fc.weight.detach())
+        script_model.sub1.fc.bias = torch.nn.Parameter(eager_model.sub1.fc.bias.detach())
+        script_model.sub2.fc1.weight = torch.nn.Parameter(eager_model.sub2.fc1.module.weight.detach())
+        script_model.sub2.fc1.bias = torch.nn.Parameter(eager_model.sub2.fc1.module.bias.detach())
+        script_model.sub2.fc2.weight = torch.nn.Parameter(eager_model.sub2.fc2.weight.detach())
+        script_model.sub2.fc2.bias = torch.nn.Parameter(eager_model.sub2.fc2.bias.detach())
+        script_model.fc3.weight = torch.nn.Parameter(eager_model.fc3.module.weight.detach())
+        script_model.fc3.bias = torch.nn.Parameter(eager_model.fc3.module.bias.detach())
+        print(eager_model(self.calib_data[0][0]))
+        # Quantize eager module
+        quantized_eager_model = quantize(eager_model, test_only_eval_fn, self.calib_data)
+
+        qconfig_dict = {
+            'sub2.fc1': default_qconfig,
+            'fc3': default_qconfig
+        }
+        quantized_script_model = quantize_script(
+            torch.jit.script(script_model),
+            qconfig_dict,
+            test_only_eval_fn,
+            [self.calib_data],
+            inplace=False)
+
+        eager_result = quantized_eager_model(self.calib_data[0][0])
+        print(get_forward(quantized_script_model._c._get_module('fc3')).graph)
+        script_result = get_forward(quantized_script_model._c)(self.calib_data[0][0])
+        print(eager_result, script_result)
+        self.assertEqual(eager_result, script_result)
+
 
 class FunctionalModuleTest(QuantizationTestCase):
     # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out
diff --git a/test/test_quantizer.py b/test/test_quantizer.py
deleted file mode 100644
index 3df63f017de..00000000000
--- a/test/test_quantizer.py
+++ /dev/null
@@ -1,164 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-
-import unittest
-import torch.jit
-from jit_utils import _tmp_donotuse_dont_inline_everything
-from torch._jit_internal import Optional
-import torch.nn as nn
-from common_utils import TestCase, run_tests
-from common_quantization import NestedModel, AnnotatedNestedModel
-from torch.quantization import QuantStub, DeQuantStub, \
-    quantize, default_eval_fn, QConfig
-
-class Observer(torch.nn.Module):
-    __annotations__ = {'scale' : Optional[torch.Tensor], 'zero_point': Optional[torch.Tensor]}
-
-    def __init__(self):
-        super(Observer, self).__init__()
-        self.dtype = torch.quint8
-        self.qscheme = torch.per_tensor_affine
-        self.scale, self.zero_point = None, None
-
-    def forward(self, x):
-        self.scale = torch.tensor([2.0])
-        self.zero_point = torch.tensor([3])
-        return x
-
-    @torch.jit.export
-    def calculate_qparams(self):
-        return self.scale, self.zero_point
-
-class WeightObserver(Observer):
-    def __init__(self):
-        super(WeightObserver, self).__init__()
-        self.dtype = torch.qint8
-
-@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
-                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
-                     " with instruction set support avx2 or newer.")
-@unittest.skip("temporarily disable the test")
-class QuantizerTestCase(TestCase):
-    @_tmp_donotuse_dont_inline_everything
-    def test_default(self):
-        class TestM(nn.Module):
-            def __init__(self, qconfig):
-                super(TestM, self).__init__()
-                self.conv = nn.Conv2d(3, 1, 3).float()
-                self.conv.weight.data.fill_(1.0)
-                self.conv.bias.data.fill_(0.01)
-                self.qconfig = qconfig
-                self.quant = QuantStub()
-                self.dequant = DeQuantStub()
-
-            def forward(self, x):
-                return self.dequant(self.conv(self.quant(x)))
-
-        class TestScriptM(torch.jit.ScriptModule):
-            def __init__(self):
-                super(TestScriptM, self).__init__()
-                self.conv = nn.Conv2d(3, 1, 3).float()
-                self.conv.bias.data.fill_(0.01)
-
-            @torch.jit.script_method
-            def forward(self, x):
-                y = self.conv(x)
-                return y
-
-        # Test Data
-        data = [(torch.randn(10, 3, 10, 10, dtype=torch.float), 1)]
-
-        # Eager mode
-        fake_qconfig = QConfig(activation=Observer, weight=WeightObserver)
-        eager_module = TestM(fake_qconfig)
-        # Script mode
-        script_module = TestScriptM()
-        script_module.conv.weight = torch.nn.Parameter(eager_module.conv.weight.detach())
-        quantized_eager_module = quantize(eager_module, default_eval_fn, data)
-
-        def get_forward(m):
-            return m._c._get_method('forward')
-        # TODO: test jit.script as well
-        ScriptedObserver = torch.jit.script(Observer())
-        ScriptedWeightObserver = torch.jit.script(WeightObserver())
-        qconfig_dict = {
-            '':
-            QConfig(
-                activation=ScriptedObserver._c,
-                weight=ScriptedWeightObserver._c)
-        }
-        torch._C._jit_pass_insert_observers(script_module._c,
-                                            "forward",
-                                            qconfig_dict)
-        # Run ScriptM Model and Collect statistics
-        get_forward(script_module)(data[0][0])
-
-        # Insert quantize and dequantize calls
-        script_module._c = torch._C._jit_pass_insert_quant_dequant(script_module._c, "forward")
-        # Note that observer modules are not removed right now
-        torch._C._jit_pass_quant_fusion(script_module._c._get_method('forward').graph)
-        get_forward(script_module)(data[0][0])
-        eager_result = quantized_eager_module(data[0][0])
-        script_result = get_forward(script_module)(data[0][0])
-        self.assertEqual(eager_result, script_result)
-
-    @_tmp_donotuse_dont_inline_everything
-    def test_qconfig_dict(self):
-        data = [(torch.randn(10, 5, dtype=torch.float) * 20, 1)]
-
-        # Eager mode
-        qconfig = QConfig(activation=Observer, weight=WeightObserver)
-        eager_module = AnnotatedNestedModel()
-        eager_module.fc3.qconfig = qconfig
-        eager_module.sub2.fc1.qconfig = qconfig
-        # Assign weights
-        eager_module.sub1.fc.weight.data.fill_(1.0)
-        eager_module.sub2.fc1.module.weight.data.fill_(1.0)
-        eager_module.sub2.fc2.weight.data.fill_(1.0)
-        eager_module.fc3.module.weight.data.fill_(1.0)
-
-        script_module = torch.jit.script(NestedModel())
-        # Copy weights for eager_module
-        script_module.sub1.fc.weight = eager_module.sub1.fc.weight
-        script_module.sub2.fc1.weight = eager_module.sub2.fc1.module.weight
-        script_module.sub2.fc2.weight = eager_module.sub2.fc2.weight
-        script_module.fc3.weight = eager_module.fc3.module.weight
-
-        # Quantize eager module
-        quantized_eager_module = quantize(eager_module, default_eval_fn, data)
-
-        def get_forward(m):
-            return m._c._get_method('forward')
-
-        # Quantize script_module
-        torch._C._jit_pass_constant_propagation(get_forward(script_module).graph)
-
-        ScriptedObserver = torch.jit.script(Observer())
-        ScriptedWeightObserver = torch.jit.script(WeightObserver())
-        scripted_qconfig = QConfig(
-            activation=ScriptedObserver._c,
-            weight=ScriptedWeightObserver._c)
-        qconfig_dict = {
-            'sub2.fc1': scripted_qconfig,
-            'fc3': scripted_qconfig
-        }
-        torch._C._jit_pass_insert_observers(script_module._c,
-                                            "forward",
-                                            qconfig_dict)
-
-        # Run script_module and Collect statistics
-        get_forward(script_module)(data[0][0])
-
-        # Insert quantize and dequantize calls
-        script_module._c = torch._C._jit_pass_insert_quant_dequant(script_module._c, "forward")
-        # Note that observer modules are not removed right now
-        torch._C._jit_pass_quant_fusion(script_module._c._get_method('forward').graph)
-        get_forward(script_module)(data[0][0])
-        eager_result = quantized_eager_module(data[0][0])
-        script_result = get_forward(script_module)(data[0][0])
-        self.assertEqual(eager_result, script_result)
-
-if __name__ == '__main__':
-    run_tests()
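
Note on the new jit_utils.get_forward import: the helper's definition is not
part of this diff. Judging from the call sites (get_forward(model_script._c)
and get_forward(quantized_script_model._c._get_module('fc3'))), it presumably
takes the lowered C++ module (the ScriptModule's `_c` attribute) rather than
the Python wrapper that the deleted test's local helper took. A minimal sketch
under that assumption:

    # Hypothetical sketch of jit_utils.get_forward; only the call sites above
    # are known from this diff, so the body here is an assumption.
    def get_forward(c_module):
        # c_module: a torch._C.ScriptModule (the `_c` of a torch.jit.ScriptModule);
        # _get_method looks up its compiled 'forward' method.
        return c_module._get_method('forward')

With quantize_script(torch.jit.script(model), qconfig_dict, test_only_eval_fn,
[calib_data], inplace=False) presumably wrapping the manual pass sequence the
deleted test drove by hand (_jit_pass_insert_observers, a calibration run,
_jit_pass_insert_quant_dequant, _jit_pass_quant_fusion), test_quantizer.py no
longer needs to exercise those passes directly.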