Numeric Suite: Enable QAT support
Enable shadow module support for QAT. In this case, the reference model is the model being trained, but with fake-quants turned off.

Differential Revision: [D26068393](https://our.internmc.facebook.com/intern/diff/D26068393/)

[ghstack-poisoned]
This commit is contained in: parent 27d89057f8, commit dade06f082
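For context, here is a minimal sketch of the workflow this diff enables, mirroring the test_prepare_qat_stubs test added below. The model definition, input shape, and the "fbgemm" engine are illustrative assumptions rather than part of the change:

    import torch
    import torch.nn as nn
    from torch.quantization import (
        DeQuantStub, QuantStub, get_default_qat_qconfig, prepare_qat,
    )
    from torch.quantization._numeric_suite import (
        get_logger_dict, prepare_qat_model_with_stubs,  # prepare_qat_model_with_stubs is added by this commit
    )

    class TwoLayer(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant, self.dequant = QuantStub(), DeQuantStub()
            self.fc1, self.fc2 = nn.Linear(5, 10), nn.Linear(10, 20)

        def forward(self, x):
            return self.dequant(self.fc2(self.fc1(self.quant(x))))

    m = TwoLayer()
    m.qconfig = get_default_qat_qconfig("fbgemm")   # assumed backend, pick the one you train with
    prepared_m = prepare_qat(m)                     # inserts observers and fake quants for QAT
    prepare_qat_model_with_stubs(prepared_m, module_swap_list=[nn.qat.Linear])
    prepared_m(torch.rand(3, 5))                    # shadowed modules also run with fake quant off
    ob_dict = get_logger_dict(prepared_m)           # per-module 'signal'/'error' stats from ErrorLogger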
@@ -6,7 +6,9 @@ from torch.quantization import (
    QuantStub,
    convert,
    default_qconfig,
    get_default_qat_qconfig,
    prepare,
    prepare_qat,
    quantize,
    quantize_dynamic,
)
@@ -17,6 +19,8 @@ from torch.quantization._numeric_suite import (
    compare_model_outputs,
    compare_model_stub,
    compare_weights,
    prepare_qat_model_with_stubs,
    get_logger_dict,
)
from torch.testing._internal.common_quantization import (
    AnnotatedConvBnReLUModel,
@@ -30,6 +34,17 @@ from torch.testing._internal.common_quantization import (
)
from torch.testing._internal.common_quantized import override_qengines


class TestTwoLayerQATModel(nn.Module):
    def __init__(self):
        super(TestTwoLayerQATModel, self).__init__()
        self.quant = QuantStub()
        self.linear1 = nn.Linear(5, 10)
        self.linear2 = nn.Linear(10, 20)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.linear2(self.linear1(self.quant(x))))


class SubModule(torch.nn.Module):
    def __init__(self):
@@ -86,6 +101,24 @@ class ModelWithFunctionals(torch.nn.Module):


class TestEagerModeNumericSuite(QuantizationTestCase):
    @override_qengines
    def test_prepare_qat_stubs(self):
        qengine = torch.backends.quantized.engine
        m = TestTwoLayerQATModel()
        x = torch.rand(3, 5)
        qconfig = get_default_qat_qconfig(qengine)
        m.qconfig = qconfig
        print(m)
        prepared_m = prepare_qat(m)
        print(prepared_m)
        prepared_m(x)
        prepare_qat_model_with_stubs(prepared_m, module_swap_list=[nn.qat.Linear])
        print(prepared_m)
        prepared_m(x)
        ob_dict = get_logger_dict(prepared_m)
        print(ob_dict)
        return

    @override_qengines
    def test_compare_weights_conv_static(self):
        r"""Compare the weights of float and static quantized conv layer"""
@@ -4,6 +4,7 @@ import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import prepare
from typing import Dict, List, Optional, Any, Union, Callable, Set
import torch.quantization as tq

from .quantization_mappings import (
    get_default_compare_output_module_list,
@@ -206,6 +207,17 @@ class OutputLogger(Logger):
        self.stats["tensor_val"].append(x)
        return x


class ErrorLogger(Logger):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        print('Error', torch.norm(x - y))
        self.stats['signal'] = torch.norm(y)
        self.stats['error'] = torch.norm(x - y)
        return


def _convert_tuple_to_list(t: Any) -> Any:
    return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t
@@ -293,6 +305,35 @@ class Shadow(nn.Module):
        self.logger(output, shadow_output)
        return output


class QATShadow(nn.Module):
    r"""Shadow module that wraps a QAT module and runs it twice on each
    forward pass, once with fake quant enabled and once with fake quant and
    observers disabled, then uses a Logger module to process the two outputs.

    Args:
        module: QAT module that we need to shadow
        Logger: type of logger used to process the output of the module
            with and without fake quant.
            ShadowLogger or custom loggers can be used.
    """

    def __init__(self, module, Logger):
        super(QATShadow, self).__init__()
        self.orig_module = module
        self.logger = Logger()

    def forward(self, x):
        output = self.orig_module(x)
        with torch.no_grad():
            # Save original state
            self.orig_module.apply(tq.disable_fake_quant)
            self.orig_module.apply(tq.disable_observer)
            shadow_output = self.orig_module(x)
            # TODO: Restore it back, currently state is not preserved
            self.orig_module.apply(tq.enable_fake_quant)
            self.orig_module.apply(tq.enable_observer)
            self.logger(output, shadow_output)

        return output


def prepare_model_with_stubs(
    float_module: nn.Module, q_module: nn.Module,
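To make the shadowing behavior concrete, a hypothetical standalone use of QATShadow on a single QAT module (in the diff it is normally attached via prepare_qat_model_with_stubs; the Linear sizes and the "fbgemm" qconfig are assumptions):

    import torch
    import torch.nn as nn
    import torch.quantization as tq

    qconfig = tq.get_default_qat_qconfig("fbgemm")        # assumed backend
    qat_linear = nn.qat.Linear(5, 10, qconfig=qconfig)    # QAT Linear with weight fake quant
    shadow = QATShadow(qat_linear, ErrorLogger)
    out = shadow(torch.rand(3, 5))      # runs once with fake quant enabled, once with it disabled
    print(shadow.logger.stats)          # {'signal': ..., 'error': ...} recorded by ErrorLogger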
@@ -346,6 +387,70 @@ def _is_identical_module_type(mod1, mod2):



def _logger_forward_hook(self, input, output):
    r"""Forward hook that calls logger on the input and output of fake quant"""
    with torch.no_grad():
        if self.fake_quant_enabled[0]:
            # Pass first input of fake quant to logger
            return self.logger(input[0], output)


def _register_logger_hook(module):
    assert hasattr(module, 'activation_post_process'), \
        'Expect activation_post_process attribute already attached to the module'
    assert hasattr(module, 'logger'), \
        'Expect logger attribute already attached to the module'
    return module.register_forward_hook(_logger_forward_hook)


def _insert_logger(module, Logger):
    for name, mod in module.named_children():
        print(name)
        if isinstance(mod, tq.FakeQuantizeBase):
            mod.add_module('logger', Logger())
            # Register logger as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the logger
            handle = _register_logger_hook(mod)
            mod._forward_hooks.move_to_end(handle.id, last=False)
        else:
            _insert_logger(mod, Logger)

    return


def prepare_qat_model_with_stubs(float_module, module_swap_list, Logger=ErrorLogger):
    r"""Prepare a QAT model for numeric comparison: a Logger is attached to
    every fake quant module, and every module whose type is in
    module_swap_list is wrapped in a QATShadow so it is also run with fake
    quant disabled.

    Example usage:
        prepare_qat_model_with_stubs(qat_model, module_swap_list, Logger)
        qat_model(data)
        ob_dict = get_logger_dict(qat_model)

    Args:
        float_module: QAT model to be prepared in place
        module_swap_list: list of module types to wrap in a QATShadow
        Logger: type of logger used in the shadow module to process the
            outputs of the module with and without fake quant
    """
    _insert_logger(float_module, Logger)
    float_module_children = {}
    reassign = {}
    for name, mod in float_module.named_children():
        float_module_children[name] = mod

        float_mod = float_module_children[name]

        if type(float_mod) not in module_swap_list:
            prepare_qat_model_with_stubs(float_mod, module_swap_list, ErrorLogger)

        if type(float_mod) in module_swap_list:
            reassign[name] = QATShadow(float_mod, Logger)

    for key, value in reassign.items():
        float_module._modules[key] = value


def compare_model_stub(
    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
    *data, logger_cls=ShadowLogger
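Once data has been run through the prepared model, the collected statistics can be read back with get_logger_dict. A sketch of how the per-module error might be summarized (the 'signal'/'error' keys follow ErrorLogger above; the SQNR formula is a common convention, not part of this diff):

    import torch
    from torch.quantization._numeric_suite import get_logger_dict

    ob_dict = get_logger_dict(prepared_m)          # prepared_m as in the usage sketch above
    for name, stats in ob_dict.items():
        if 'signal' in stats and 'error' in stats:
            sqnr = 20 * torch.log10(stats['signal'] / stats['error'])
            print(name, 'SQNR (dB):', sqnr.item())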