mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28807

`FAIL: test_numerical_consistency_per_channel (__main__.TestFakeQuantizePerChannel)`

This test is failing consistently on master; we can't find a clean blame.

ghstack-source-id: 92763176
Test Plan: CI
Differential Revision: D18181496
fbshipit-source-id: 5948af06c4cb7dea9a8db1366deb7c12f6ec1c72
165 lines · 6.5 KiB · Python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest
import torch.jit
from jit_utils import _tmp_donotuse_dont_inline_everything
from torch._jit_internal import Optional
import torch.nn as nn
from common_utils import TestCase, run_tests
from common_quantization import NestedModel, AnnotatedNestedModel
from torch.quantization import QuantStub, DeQuantStub, \
    quantize, default_eval_fn, QConfig

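# Fake observer shared by both tests: instead of collecting real statistics,
# forward() records fixed qparams (scale=2.0, zero_point=3) which
# calculate_qparams() returns, so the eager and scripted paths end up with
# identical quantization parameters.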
class Observer(torch.nn.Module):
    __annotations__ = {'scale' : Optional[torch.Tensor], 'zero_point': Optional[torch.Tensor]}

    def __init__(self):
        super(Observer, self).__init__()
        self.dtype = torch.quint8
        self.qscheme = torch.per_tensor_affine
        self.scale, self.zero_point = None, None

    def forward(self, x):
        self.scale = torch.tensor([2.0])
        self.zero_point = torch.tensor([3])
        return x

    @torch.jit.export
    def calculate_qparams(self):
        return self.scale, self.zero_point

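# Same fake observer, but reporting qint8 so it can stand in for a weight observer.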
class WeightObserver(Observer):
    def __init__(self):
        super(WeightObserver, self).__init__()
        self.dtype = torch.qint8

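# The whole test case is currently skipped; when re-enabled, the skipUnless
# guard restricts it to builds with the FBGEMM quantization backend.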
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                     " with instruction set support avx2 or newer.")
@unittest.skip("temporarily disable the test")
class QuantizerTestCase(TestCase):
    @_tmp_donotuse_dont_inline_everything
    def test_default(self):
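        # Eager-mode model under test; QuantStub/DeQuantStub mark the boundary
        # where quantize() inserts the quantize/dequantize steps.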
        class TestM(nn.Module):
            def __init__(self, qconfig):
                super(TestM, self).__init__()
                self.conv = nn.Conv2d(3, 1, 3).float()
                self.conv.weight.data.fill_(1.0)
                self.conv.bias.data.fill_(0.01)
                self.qconfig = qconfig
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                return self.dequant(self.conv(self.quant(x)))

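        # Script-mode counterpart: no stubs here, the JIT passes below insert
        # the observer and quant/dequant calls directly into the graph.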
        class TestScriptM(torch.jit.ScriptModule):
            def __init__(self):
                super(TestScriptM, self).__init__()
                self.conv = nn.Conv2d(3, 1, 3).float()
                self.conv.bias.data.fill_(0.01)

            @torch.jit.script_method
            def forward(self, x):
                y = self.conv(x)
                return y

        # Test Data
        data = [(torch.randn(10, 3, 10, 10, dtype=torch.float), 1)]

        # Eager mode
        fake_qconfig = QConfig(activation=Observer, weight=WeightObserver)
        eager_module = TestM(fake_qconfig)
        # Script mode
        script_module = TestScriptM()
        script_module.conv.weight = torch.nn.Parameter(eager_module.conv.weight.detach())
        quantized_eager_module = quantize(eager_module, default_eval_fn, data)

        def get_forward(m):
            return m._c._get_method('forward')
        # TODO: test jit.script as well
        ScriptedObserver = torch.jit.script(Observer())
        ScriptedWeightObserver = torch.jit.script(WeightObserver())
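        # The empty key '' targets the top-level module, so this qconfig covers
        # the whole model.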
        qconfig_dict = {
            '':
            QConfig(
                activation=ScriptedObserver._c,
                weight=ScriptedWeightObserver._c)
        }
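        # Internal JIT pass that inserts observer calls into the forward graph
        # for the values selected by qconfig_dict.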
        torch._C._jit_pass_insert_observers(script_module._c,
                                            "forward",
                                            qconfig_dict)
        # Run ScriptM Model and Collect statistics
        get_forward(script_module)(data[0][0])

        # Insert quantize and dequantize calls
        script_module._c = torch._C._jit_pass_insert_quant_dequant(script_module._c, "forward")
        # Note that observer modules are not removed right now
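        # Fuse the inserted quant/dequant + op patterns into quantized ops in the graph.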
        torch._C._jit_pass_quant_fusion(script_module._c._get_method('forward').graph)
        get_forward(script_module)(data[0][0])
        eager_result = quantized_eager_module(data[0][0])
        script_result = get_forward(script_module)(data[0][0])
        self.assertEqual(eager_result, script_result)

    @_tmp_donotuse_dont_inline_everything
    def test_qconfig_dict(self):
        data = [(torch.randn(10, 5, dtype=torch.float) * 20, 1)]

        # Eager mode
        qconfig = QConfig(activation=Observer, weight=WeightObserver)
        eager_module = AnnotatedNestedModel()
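        # fc3 and sub2.fc1 get the fake qconfig, mirroring the 'fc3' and
        # 'sub2.fc1' entries used for the scripted model below.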
        eager_module.fc3.qconfig = qconfig
        eager_module.sub2.fc1.qconfig = qconfig
        # Assign weights
        eager_module.sub1.fc.weight.data.fill_(1.0)
        eager_module.sub2.fc1.module.weight.data.fill_(1.0)
        eager_module.sub2.fc2.weight.data.fill_(1.0)
        eager_module.fc3.module.weight.data.fill_(1.0)

        script_module = torch.jit.script(NestedModel())
        # Copy weights from eager_module
        script_module.sub1.fc.weight = eager_module.sub1.fc.weight
        script_module.sub2.fc1.weight = eager_module.sub2.fc1.module.weight
        script_module.sub2.fc2.weight = eager_module.sub2.fc2.weight
        script_module.fc3.weight = eager_module.fc3.module.weight

        # Quantize eager module
        quantized_eager_module = quantize(eager_module, default_eval_fn, data)

        def get_forward(m):
            return m._c._get_method('forward')

        # Quantize script_module
        torch._C._jit_pass_constant_propagation(get_forward(script_module).graph)

        ScriptedObserver = torch.jit.script(Observer())
        ScriptedWeightObserver = torch.jit.script(WeightObserver())
        scripted_qconfig = QConfig(
            activation=ScriptedObserver._c,
            weight=ScriptedWeightObserver._c)
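        # Per-module qconfig_dict: only 'sub2.fc1' and 'fc3' are observed and
        # quantized, matching the eager-mode qconfig assignments above.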
        qconfig_dict = {
            'sub2.fc1': scripted_qconfig,
            'fc3': scripted_qconfig
        }
        torch._C._jit_pass_insert_observers(script_module._c,
                                            "forward",
                                            qconfig_dict)

        # Run script_module and Collect statistics
        get_forward(script_module)(data[0][0])

        # Insert quantize and dequantize calls
        script_module._c = torch._C._jit_pass_insert_quant_dequant(script_module._c, "forward")
        # Note that observer modules are not removed right now
        torch._C._jit_pass_quant_fusion(script_module._c._get_method('forward').graph)
        get_forward(script_module)(data[0][0])
        eager_result = quantized_eager_module(data[0][0])
        script_result = get_forward(script_module)(data[0][0])
        self.assertEqual(eager_result, script_result)


if __name__ == '__main__':
    run_tests()