Numeric Suite: Enable QAT support

Enable shadow module support for QAT. In this case, the reference model is the model being trained itself, with fake quantization turned off.
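A rough usage sketch based on the functions added in this diff; `MyQATModel` and `example_input` are placeholders, and the backend string is just an example:

```python
import torch
import torch.nn as nn
from torch.quantization import get_default_qat_qconfig, prepare_qat
from torch.quantization._numeric_suite import (
    prepare_qat_model_with_stubs,
    get_logger_dict,
)

model = MyQATModel()  # placeholder: any model with QuantStub/DeQuantStub
model.qconfig = get_default_qat_qconfig('fbgemm')
prepared = prepare_qat(model)

# Shadow the QAT linear modules and attach loggers to the fake quants
prepare_qat_model_with_stubs(prepared, module_swap_list=[nn.qat.Linear])
prepared(example_input)              # forward pass populates the loggers
ob_dict = get_logger_dict(prepared)  # per-module error/signal stats
```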

Differential Revision: [D26068393](https://our.internmc.facebook.com/intern/diff/D26068393/)

[ghstack-poisoned]
Raghu Krishnamoorthi 2021-02-17 19:58:10 -08:00
parent 27d89057f8
commit dade06f082
2 changed files with 138 additions and 0 deletions


@@ -6,7 +6,9 @@ from torch.quantization import (
    QuantStub,
    convert,
    default_qconfig,
    get_default_qat_qconfig,
    prepare,
    prepare_qat,
    quantize,
    quantize_dynamic,
)
@@ -17,6 +19,8 @@ from torch.quantization._numeric_suite import (
    compare_model_outputs,
    compare_model_stub,
    compare_weights,
    prepare_qat_model_with_stubs,
    get_logger_dict,
)
from torch.testing._internal.common_quantization import (
    AnnotatedConvBnReLUModel,
@@ -30,6 +34,17 @@ from torch.testing._internal.common_quantization import (
)
from torch.testing._internal.common_quantized import override_qengines
class TestTwoLayerQATModel(nn.Module):
    def __init__(self):
        super(TestTwoLayerQATModel, self).__init__()
        self.quant = QuantStub()
        self.linear1 = nn.Linear(5, 10)
        self.linear2 = nn.Linear(10, 20)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.linear2(self.linear1(self.quant(x))))
class SubModule(torch.nn.Module):
    def __init__(self):
@@ -86,6 +101,24 @@ class ModelWithFunctionals(torch.nn.Module):
class TestEagerModeNumericSuite(QuantizationTestCase):
    @override_qengines
    def test_prepare_qat_stubs(self):
        qengine = torch.backends.quantized.engine
        m = TestTwoLayerQATModel()
        x = torch.rand(3, 5)
        qconfig = get_default_qat_qconfig(qengine)
        m.qconfig = qconfig
        print(m)
        prepared_m = prepare_qat(m)
        print(prepared_m)
        prepared_m(x)
        prepare_qat_model_with_stubs(prepared_m, module_swap_list=[nn.qat.Linear])
        print(prepared_m)
        prepared_m(x)
        ob_dict = get_logger_dict(prepared_m)
        print(ob_dict)
        return
    @override_qengines
    def test_compare_weights_conv_static(self):
        r"""Compare the weights of float and static quantized conv layer"""


@@ -4,6 +4,7 @@ import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import prepare
from typing import Dict, List, Optional, Any, Union, Callable, Set
import torch.quantization as tq
from .quantization_mappings import (
    get_default_compare_output_module_list,
@@ -206,6 +207,17 @@ class OutputLogger(Logger):
        self.stats["tensor_val"].append(x)
        return x
class ErrorLogger(Logger):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        print('Error', torch.norm(x - y))
        self.stats['signal'] = torch.norm(y)
        self.stats['error'] = torch.norm(x - y)
        return
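# Illustrative aside, not part of this diff: once an ErrorLogger has been run,
# a quantization SQNR (in dB) can be derived from the two norms it records.
# The helper name below is hypothetical.
def _sqnr_from_error_logger(logger):
    signal = logger.stats['signal']   # ||y||, output without fake quant
    error = logger.stats['error']     # ||x - y||
    return 20 * torch.log10(signal / error)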
def _convert_tuple_to_list(t: Any) -> Any:
    return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t
@@ -293,6 +305,35 @@ class Shadow(nn.Module):
        self.logger(output, shadow_output)
        return output
class QATShadow(nn.Module):
    r"""QATShadow module runs a QAT module twice, once with fake quantization
    enabled and once with it disabled, and uses a Logger module to process the
    two outputs.

    Args:
        module: QAT module whose outputs with and without fake quant we compare
        Logger: type of logger used to process the outputs of the module with
            and without fake quant. ShadowLogger or custom loggers can be used.
    """

    def __init__(self, module, Logger):
        super(QATShadow, self).__init__()
        self.orig_module = module
        self.logger = Logger()

    def forward(self, x):
        output = self.orig_module(x)
        with torch.no_grad():
            # Disable fake quant and observers to get the reference output
            self.orig_module.apply(tq.disable_fake_quant)
            self.orig_module.apply(tq.disable_observer)
            shadow_output = self.orig_module(x)
            # TODO: restore the original enabled/disabled state; for now fake
            # quant and observers are simply re-enabled.
            self.orig_module.apply(tq.enable_fake_quant)
            self.orig_module.apply(tq.enable_observer)
        self.logger(output, shadow_output)
        return output
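# Illustrative sketch, not part of this diff: QATShadow can also be applied to
# a single QAT-prepared module directly; `qat_linear` below is a hypothetical
# nn.qat.Linear that already carries fake quant modules.
#
#   shadow = QATShadow(qat_linear, ErrorLogger)
#   _ = shadow(torch.rand(4, qat_linear.in_features))
#   print(shadow.logger.stats['error'], shadow.logger.stats['signal'])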
def prepare_model_with_stubs(
    float_module: nn.Module, q_module: nn.Module,
@@ -346,6 +387,70 @@ def _is_identical_module_type(mod1, mod2):
def _logger_forward_hook(self, input, output):
    r"""Forward hook that calls logger on the input and output of fake quant
    """
    with torch.no_grad():
        if self.fake_quant_enabled[0]:
            # Pass first input of fake quant to logger
            return self.logger(input[0], output)
def _register_logger_hook(module):
    assert hasattr(module, 'activation_post_process'), \
        'Expect activation_post_process attribute already attached to the module'
    assert hasattr(module, 'logger'), \
        'Expect logger attribute already attached to the module'
    return module.register_forward_hook(_logger_forward_hook)
def _insert_logger(module, Logger):
    for name, mod in module.named_children():
        print(name)
        if isinstance(mod, tq.FakeQuantizeBase):
            mod.add_module('logger', Logger())
            # Register logger as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the logger
            handle = _register_logger_hook(mod)
            mod._forward_hooks.move_to_end(handle.id, last=False)
        else:
            _insert_logger(mod, Logger)
    return
def prepare_qat_model_with_stubs(float_module, module_swap_list, Logger=ErrorLogger):
    r"""Prepare a QAT model for numeric comparison by attaching a Logger to
    every fake quant module and by wrapping each child module whose type is in
    module_swap_list with a QATShadow, so that outputs with and without fake
    quant can be compared.

    Example usage:
        prepare_qat_model_with_stubs(qat_model, module_swap_list, Logger)
        qat_model(data)
        ob_dict = get_logger_dict(qat_model)

    Args:
        float_module: QAT model to attach the loggers and shadow modules to
        module_swap_list: list of module types to wrap with the shadow module
        Logger: type of logger to be used in the shadow module to process the
            outputs of the module with and without fake quant
    """
    _insert_logger(float_module, Logger)

    float_module_children = {}
    reassign = {}
    for name, mod in float_module.named_children():
        float_module_children[name] = mod
        float_mod = float_module_children[name]

        if type(float_mod) not in module_swap_list:
            prepare_qat_model_with_stubs(float_mod, module_swap_list, Logger)

        if type(float_mod) in module_swap_list:
            reassign[name] = QATShadow(float_mod, Logger)

    for key, value in reassign.items():
        float_module._modules[key] = value
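# Illustrative follow-up, not part of this diff: after the prepared model has
# seen data, the per-module stats recorded by the loggers can be inspected.
#
#   ob_dict = get_logger_dict(qat_model)
#   for name, stats in ob_dict.items():
#       print(name, stats.get('signal'), stats.get('error'))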
def compare_model_stub(
    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
    *data, logger_cls=ShadowLogger