Remove test_quantizer.py and reuse one of its tests in test_quantization.py (#27269)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27269

Remove `test_quantizer.py` and add a rewritten version of one of its tests to `test_quantization.py`. The conv test is removed for now since the conv pattern is still broken; we'll add another test later.

ghstack-source-id: 92869823

Test Plan: python test/test_quantization.py

Imported from OSS

Differential Revision: D18182916

fbshipit-source-id: 325b5d8e877228d6a513e3ddf52c974479250d42
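The test retained from `test_quantizer.py` checks that eager-mode and graph-mode quantization agree. As a minimal sketch of that parity pattern, assuming `quantize`, `quantize_script`, `test_only_eval_fn`, and `get_forward` are in scope as they are in the test below (the model and calibration data are whatever the caller supplies; the import homes of `quantize_script` and `test_only_eval_fn` are not shown in this diff):

    import torch
    # Assumed in scope, matching the names used in the diff below:
    # quantize (torch.quantization), quantize_script, test_only_eval_fn,
    # and get_forward (from jit_utils).

    def check_parity(eager_model, script_model, qconfig_dict, calib_data):
        # Post-training-quantize the eager-mode copy.
        quantized_eager = quantize(eager_model, test_only_eval_fn, calib_data)
        # Quantize the scripted copy through the graph-mode pipeline.
        quantized_script = quantize_script(
            torch.jit.script(script_model),
            qconfig_dict,
            test_only_eval_fn,
            [calib_data],
            inplace=False)
        # Both paths should produce the same output on a calibration sample.
        x = calib_data[0][0]
        return quantized_eager(x), get_forward(quantized_script._c)(x)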
This commit is contained in:
parent dfe7b25eaf
commit 1c436ded44
.jenkins/pytorch/test.sh

@@ -110,7 +110,7 @@ test_python_nn() {
 }
 
 test_python_all_except_nn() {
-  time python test/run_test.py --exclude nn --verbose --bring-to-front quantization quantized quantized_tensor quantized_nn_mods quantizer
+  time python test/run_test.py --exclude nn --verbose --bring-to-front quantization quantized quantized_tensor quantized_nn_mods
   assert_git_not_dirty
 }
 
test/run_test.py

@@ -47,7 +47,6 @@ TESTS = [
     'quantized',
     'quantized_tensor',
     'quantized_nn_mods',
-    'quantizer',
     'sparse',
     'torch',
     'type_info',
test/test_quantization.py

@@ -31,6 +31,7 @@ from common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel
     AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel
 
 from jit_utils import _tmp_donotuse_dont_inline_everything
+from jit_utils import get_forward
 
 from hypothesis import given
 from hypothesis import strategies as st
@@ -695,10 +696,49 @@ class GraphModePostTrainingQuantTest(QuantizationTestCase):
             [self.calib_data],
             inplace=False)
         result_eager = model_eager(self.calib_data[0][0])
-        torch._C._jit_pass_quant_fusion(model_script._c._get_module('fc1')._get_method('forward').graph)
-        result_script = model_script._c._get_method('forward')(self.calib_data[0][0])
+        result_script = get_forward(model_script._c)(self.calib_data[0][0])
         self.assertEqual(result_eager, result_script)
 
+    @unittest.skip("quantization for inlined linear is not working right now")
+    def test_nested(self):
+        # Eager mode
+        eager_model = AnnotatedNestedModel()
+        # default_per_channel_qconfig is not scriptable right now,
+        # temporarily change to default_qconfig until default_per_channel_qconfig is fixed
+        eager_model.sub2.fc1.qconfig = default_qconfig
+
+        # Graph mode
+        script_model = NestedModel()
+        # Copy weights for eager_model
+        script_model.sub1.fc.weight = torch.nn.Parameter(eager_model.sub1.fc.weight.detach())
+        script_model.sub1.fc.bias = torch.nn.Parameter(eager_model.sub1.fc.bias.detach())
+        script_model.sub2.fc1.weight = torch.nn.Parameter(eager_model.sub2.fc1.module.weight.detach())
+        script_model.sub2.fc1.bias = torch.nn.Parameter(eager_model.sub2.fc1.module.bias.detach())
+        script_model.sub2.fc2.weight = torch.nn.Parameter(eager_model.sub2.fc2.weight.detach())
+        script_model.sub2.fc2.bias = torch.nn.Parameter(eager_model.sub2.fc2.bias.detach())
+        script_model.fc3.weight = torch.nn.Parameter(eager_model.fc3.module.weight.detach())
+        script_model.fc3.bias = torch.nn.Parameter(eager_model.fc3.module.bias.detach())
+        print(eager_model(self.calib_data[0][0]))
+        # Quantize eager module
+        quantized_eager_model = quantize(eager_model, test_only_eval_fn, self.calib_data)
+
+        qconfig_dict = {
+            'sub2.fc1': default_qconfig,
+            'fc3': default_qconfig
+        }
+        quantized_script_model = quantize_script(
+            torch.jit.script(script_model),
+            qconfig_dict,
+            test_only_eval_fn,
+            [self.calib_data],
+            inplace=False)
+
+        eager_result = quantized_eager_model(self.calib_data[0][0])
+        print(get_forward(quantized_script_model._c._get_module('fc3')).graph)
+        script_result = get_forward(quantized_script_model._c)(self.calib_data[0][0])
+        print(eager_result, script_result)
+        self.assertEqual(eager_result, script_result)
+
 
 class FunctionalModuleTest(QuantizationTestCase):
     # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out
test/test_quantizer.py (deleted)

@@ -1,164 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-
-import unittest
-import torch.jit
-from jit_utils import _tmp_donotuse_dont_inline_everything
-from torch._jit_internal import Optional
-import torch.nn as nn
-from common_utils import TestCase, run_tests
-from common_quantization import NestedModel, AnnotatedNestedModel
-from torch.quantization import QuantStub, DeQuantStub, \
-    quantize, default_eval_fn, QConfig
-
-class Observer(torch.nn.Module):
-    __annotations__ = {'scale' : Optional[torch.Tensor], 'zero_point': Optional[torch.Tensor]}
-
-    def __init__(self):
-        super(Observer, self).__init__()
-        self.dtype = torch.quint8
-        self.qscheme = torch.per_tensor_affine
-        self.scale, self.zero_point = None, None
-
-    def forward(self, x):
-        self.scale = torch.tensor([2.0])
-        self.zero_point = torch.tensor([3])
-        return x
-
-    @torch.jit.export
-    def calculate_qparams(self):
-        return self.scale, self.zero_point
-
-class WeightObserver(Observer):
-    def __init__(self):
-        super(WeightObserver, self).__init__()
-        self.dtype = torch.qint8
-
-@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
-                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
-                     " with instruction set support avx2 or newer.")
-@unittest.skip("temporarily disable the test")
-class QuantizerTestCase(TestCase):
-    @_tmp_donotuse_dont_inline_everything
-    def test_default(self):
-        class TestM(nn.Module):
-            def __init__(self, qconfig):
-                super(TestM, self).__init__()
-                self.conv = nn.Conv2d(3, 1, 3).float()
-                self.conv.weight.data.fill_(1.0)
-                self.conv.bias.data.fill_(0.01)
-                self.qconfig = qconfig
-                self.quant = QuantStub()
-                self.dequant = DeQuantStub()
-
-            def forward(self, x):
-                return self.dequant(self.conv(self.quant(x)))
-
-        class TestScriptM(torch.jit.ScriptModule):
-            def __init__(self):
-                super(TestScriptM, self).__init__()
-                self.conv = nn.Conv2d(3, 1, 3).float()
-                self.conv.bias.data.fill_(0.01)
-
-            @torch.jit.script_method
-            def forward(self, x):
-                y = self.conv(x)
-                return y
-
-        # Test Data
-        data = [(torch.randn(10, 3, 10, 10, dtype=torch.float), 1)]
-
-        # Eager mode
-        fake_qconfig = QConfig(activation=Observer, weight=WeightObserver)
-        eager_module = TestM(fake_qconfig)
-        # Script mode
-        script_module = TestScriptM()
-        script_module.conv.weight = torch.nn.Parameter(eager_module.conv.weight.detach())
-        quantized_eager_module = quantize(eager_module, default_eval_fn, data)
-
-        def get_forward(m):
-            return m._c._get_method('forward')
-        # TODO: test jit.script as well
-        ScriptedObserver = torch.jit.script(Observer())
-        ScriptedWeightObserver = torch.jit.script(WeightObserver())
-        qconfig_dict = {
-            '':
-            QConfig(
-                activation=ScriptedObserver._c,
-                weight=ScriptedWeightObserver._c)
-        }
-        torch._C._jit_pass_insert_observers(script_module._c,
-                                            "forward",
-                                            qconfig_dict)
-        # Run ScriptM Model and Collect statistics
-        get_forward(script_module)(data[0][0])
-
-        # Insert quantize and dequantize calls
-        script_module._c = torch._C._jit_pass_insert_quant_dequant(script_module._c, "forward")
-        # Note that observer modules are not removed right now
-        torch._C._jit_pass_quant_fusion(script_module._c._get_method('forward').graph)
-        get_forward(script_module)(data[0][0])
-        eager_result = quantized_eager_module(data[0][0])
-        script_result = get_forward(script_module)(data[0][0])
-        self.assertEqual(eager_result, script_result)
-
-    @_tmp_donotuse_dont_inline_everything
-    def test_qconfig_dict(self):
-        data = [(torch.randn(10, 5, dtype=torch.float) * 20, 1)]
-
-        # Eager mode
-        qconfig = QConfig(activation=Observer, weight=WeightObserver)
-        eager_module = AnnotatedNestedModel()
-        eager_module.fc3.qconfig = qconfig
-        eager_module.sub2.fc1.qconfig = qconfig
-        # Assign weights
-        eager_module.sub1.fc.weight.data.fill_(1.0)
-        eager_module.sub2.fc1.module.weight.data.fill_(1.0)
-        eager_module.sub2.fc2.weight.data.fill_(1.0)
-        eager_module.fc3.module.weight.data.fill_(1.0)
-
-        script_module = torch.jit.script(NestedModel())
-        # Copy weights for eager_module
-        script_module.sub1.fc.weight = eager_module.sub1.fc.weight
-        script_module.sub2.fc1.weight = eager_module.sub2.fc1.module.weight
-        script_module.sub2.fc2.weight = eager_module.sub2.fc2.weight
-        script_module.fc3.weight = eager_module.fc3.module.weight
-
-        # Quantize eager module
-        quantized_eager_module = quantize(eager_module, default_eval_fn, data)
-
-        def get_forward(m):
-            return m._c._get_method('forward')
-
-        # Quantize script_module
-        torch._C._jit_pass_constant_propagation(get_forward(script_module).graph)
-
-        ScriptedObserver = torch.jit.script(Observer())
-        ScriptedWeightObserver = torch.jit.script(WeightObserver())
-        scripted_qconfig = QConfig(
-            activation=ScriptedObserver._c,
-            weight=ScriptedWeightObserver._c)
-        qconfig_dict = {
-            'sub2.fc1': scripted_qconfig,
-            'fc3': scripted_qconfig
-        }
-        torch._C._jit_pass_insert_observers(script_module._c,
-                                            "forward",
-                                            qconfig_dict)
-
-        # Run script_module and Collect statistics
-        get_forward(script_module)(data[0][0])
-
-        # Insert quantize and dequantize calls
-        script_module._c = torch._C._jit_pass_insert_quant_dequant(script_module._c, "forward")
-        # Note that observer modules are not removed right now
-        torch._C._jit_pass_quant_fusion(script_module._c._get_method('forward').graph)
-        get_forward(script_module)(data[0][0])
-        eager_result = quantized_eager_module(data[0][0])
-        script_result = get_forward(script_module)(data[0][0])
-        self.assertEqual(eager_result, script_result)
-
-if __name__ == '__main__':
-    run_tests()
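For reference, the deleted test drove the graph-mode passes by hand rather than going through `quantize_script`. Condensed from the file above (all names are as defined there), the sequence was:

    # 1. Attach scripted observer modules to the forward graph.
    torch._C._jit_pass_insert_observers(script_module._c, "forward", qconfig_dict)
    # 2. Run the model once so the observers collect statistics.
    get_forward(script_module)(data[0][0])
    # 3. Rewrite the graph with explicit quantize/dequantize calls.
    script_module._c = torch._C._jit_pass_insert_quant_dequant(script_module._c, "forward")
    # 4. Fuse quantize/op/dequantize patterns into quantized ops.
    torch._C._jit_pass_quant_fusion(script_module._c._get_method('forward').graph)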