Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
torch.ao migration: numeric suite, eager and fx (#64817)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64817

This migrates `torch.quantization._numeric_suite` to `torch.ao.ns._numeric_suite`, and `torch.quantization._numeric_suite_fx` to `torch.ao.ns._numeric_suite_fx`.

1. Move the files:

   ```
   HG: move eager mode
   hg mv caffe2/torch/quantization/_numeric_suite.py caffe2/torch/ao/ns/
   HG: move fx
   hg mv caffe2/torch/quantization/_numeric_suite_fx.py caffe2/torch/ao/ns/
   hg mv caffe2/torch/quantization/ns/* caffe2/torch/ao/ns/fx/
   ```

2. Create new versions of `_numeric_suite.py` and `_numeric_suite_fx.py` at the old location, containing only imports from the new location.
3. Update all FB callsites.

Test Plan: buck test mode/dev //caffe2/test:quantization

Reviewed By: z-a-f

Differential Revision: D30867538

fbshipit-source-id: 120ee830434ca490c1183a187a518eebcbbaf22c
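Step 2 above keeps the old import locations working as thin re-export shims (see the rewritten `torch/quantization/_numeric_suite.py` at the end of this diff). A minimal sketch of what this preserves for existing callers; the specific names checked here are illustrative:

```python
# Both import paths resolve to the same objects after the migration,
# because the old module now re-exports everything from the new location.
import torch.quantization._numeric_suite as old_ns
import torch.ao.ns._numeric_suite as new_ns

assert old_ns.compare_weights is new_ns.compare_weights
assert old_ns.Shadow is new_ns.Shadow
assert old_ns.OutputLogger is new_ns.OutputLogger
```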
This commit is contained in:
parent 39f2b9de2a
commit 1577c106dc
```diff
@@ -5,7 +5,7 @@ from torch.testing._internal.common_quantization import skipIfNoFBGEMM
 from torch.quantization import default_qconfig
 from torch.quantization import QuantWrapper
-import torch.quantization._numeric_suite as ns
+import torch.ao.ns._numeric_suite as ns

 from torch.quantization._correct_bias import (
     _supported_modules,
```
```diff
@@ -10,7 +10,7 @@ from torch.quantization import (
     quantize,
     quantize_dynamic,
 )
-from torch.quantization._numeric_suite import (
+from torch.ao.ns._numeric_suite import (
     OutputLogger,
     Shadow,
     ShadowLogger,
```
```diff
@@ -34,29 +34,29 @@ from torch.quantization.quantization_mappings import (
 from torch.testing._internal.common_quantization import NodeSpec as ns
 from torch.quantization.fx.pattern_utils import get_default_quant_patterns
 import torch.quantization.fx.quantization_patterns as qp
-from torch.quantization.ns.pattern_utils import (
+from torch.ao.ns.fx.pattern_utils import (
     get_type_a_related_to_b,
 )
-from torch.quantization.ns.graph_matcher import (
+from torch.ao.ns.fx.graph_matcher import (
     get_matching_subgraph_pairs,
     GraphMatchingException,
 )
-from torch.quantization.ns.utils import (
+from torch.ao.ns.fx.utils import (
     compute_sqnr,
     compute_normalized_l2_error,
     compute_cosine_similarity,
 )
-from torch.quantization.ns.mappings import (
+from torch.ao.ns.fx.mappings import (
     get_node_type_to_io_type_map,
     get_unmatchable_types_map,
     get_base_name_to_sets_of_related_ops,
     get_base_name_for_op,
     add_op_to_sets_of_related_ops,
 )
-from torch.quantization.ns.weight_utils import (
+from torch.ao.ns.fx.weight_utils import (
     get_op_to_type_to_weight_extraction_fn,
 )
-from torch.quantization._numeric_suite_fx import (
+from torch.ao.ns._numeric_suite_fx import (
     extract_weights,
     _extract_weights_impl,
     add_loggers,
```
```diff
@@ -1634,7 +1634,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
         op_to_type_to_weight_extraction_fn = \
             get_op_to_type_to_weight_extraction_fn()
         op_to_type_to_weight_extraction_fn['call_function'][_wrapped_linear] = \
-            torch.quantization.ns.weight_utils.get_linear_fun_weight
+            torch.ao.ns.fx.weight_utils.get_linear_fun_weight

         # test compare weights
         results = extract_weights(
```
torch/ao/ns/_numeric_suite.py (new file, 486 lines):

```python
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import prepare
from typing import Dict, List, Optional, Any, Union, Callable, Set

from torch.quantization.quantization_mappings import (
    get_default_compare_output_module_list,
)

NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
    nnqd.Linear,
    nnq.Linear,
    nnqd.LSTM,
    nn.LSTM,
}


def _find_match(
    str_list: Union[Dict[str, Any], List[str]], key_str: str,
    postfix: str,
) -> Optional[str]:
    split_str = key_str.split(".")
    if split_str[-1] == postfix:
        match_string = "".join(key_str.split(".")[0:-1])
        for s2 in str_list:
            pattern1 = "".join(s2.split(".")[0:-1])
            pattern2 = "".join(s2.split(".")[0:-2])
            if match_string == pattern1:
                return s2
            if match_string == pattern2:
                return s2

        # For matching "fc.weight" and "fc._packed_params._packed_params"
        if postfix == "_packed_params":
            match_string = "".join(key_str.split(".")[0:-2])
            if len(match_string) == 0:
                return None
            for s2 in str_list:
                pattern1 = "".join(s2.split(".")[0:-1])
                pattern2 = "".join(s2.split(".")[0:-2])
                if match_string == pattern1:
                    return s2
                if match_string == pattern2:
                    return s2
        return None
    else:
        return None


def compare_weights(
    float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare the weights of the float module with its corresponding quantized
    module. Return a dict with key corresponding to module names and each entry being
    a dictionary with two keys 'float' and 'quantized', containing the float and
    quantized weights. This dict can be used to compare and compute the quantization
    error of the weights of float and quantized models.

    Example usage:
        wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())
        for key in wt_compare_dict:
            print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))

    Args:
        float_dict: state dict of the float model
        quantized_dict: state dict of the quantized model

    Return:
        weight_dict: dict with key corresponding to module names and each entry being
        a dictionary with two keys 'float' and 'quantized', containing the float and
        quantized weights
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
    weight_dict: Dict[str, Dict] = {}
    for key in quantized_dict:
        match_key = _find_match(float_dict, key, "weight")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            weight_dict[key]["quantized"] = quantized_dict[key]
            continue

        # For matching "fc.weight" and "fc._packed_params._packed_params"
        match_key = _find_match(float_dict, key, "_packed_params")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            weight_dict[key]["quantized"] = quantized_dict[key][0]

        # For LSTM
        split_str = key.split(".")
        if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
            layer = split_str[-2]
            module_name = ".".join(split_str[:-3])
            float_weight_ih_key = module_name + ".weight_ih_l" + layer
            float_weight_hh_key = module_name + ".weight_hh_l" + layer
            if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
                weight_dict[key] = {}
                weight_dict[key]["float"] = float_dict[float_weight_ih_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
                )
                weight_dict[key]["float"] = float_dict[float_weight_hh_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
                )

    return weight_dict


def _get_logger_dict_helper(
    mod: nn.Module, target_dict: Dict[str, Any],
    prefix: str = "",
) -> None:
    r"""This is the helper function for get_logger_dict

    Args:
        mod: module we want to save all logger stats
        prefix: prefix for the current module
        target_dict: the dictionary used to save all logger stats
    """

    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + "."

    for name, child in mod.named_children():
        if isinstance(child, Logger):
            target_dict[get_prefix(prefix) + "stats"] = child.stats
            break

    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        _get_logger_dict_helper(child, target_dict, module_prefix)


def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
    r"""Traverse the modules and save all logger stats into target dict.
    This is mainly used for quantization accuracy debug.

    Type of loggers supported:
        ShadowLogger: used to log the outputs of the quantized module and its
            matching float shadow module,
        OutputLogger: used to log the outputs of the modules

    Args:
        mod: module we want to save all logger stats
        prefix: prefix for the current module

    Return:
        target_dict: the dictionary used to save all logger stats
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")

    target_dict: Dict[str, Dict] = {}
    _get_logger_dict_helper(mod, target_dict, prefix)
    return target_dict


class Logger(nn.Module):
    r"""Base class for stats logging
    """

    def __init__(self):
        super(Logger, self).__init__()
        self.stats = {}
        # We only insert observer if the op is quantized with static quantization,
        # which is identified by activation_observer.dtype == quint8. This is needed
        # when attaching Logger as observer for FX mode
        self.dtype = torch.quint8

    def forward(self, x):
        pass


class ShadowLogger(Logger):
    r"""Class used in Shadow module to record the outputs of the original and
    shadow modules.
    """

    def __init__(self):
        super(ShadowLogger, self).__init__()
        self.stats["float"] = []
        self.stats["quantized"] = []

    def forward(self, x, y):
        if len(x) > 1:
            x = x[0]
        if len(y) > 1:
            y = y[0]
        self.stats["quantized"].append(x.detach())
        self.stats["float"].append(y.detach())


class OutputLogger(Logger):
    r"""Class used to log the outputs of the module
    """

    def __init__(self):
        super(OutputLogger, self).__init__()
        self.stats["tensor_val"] = []

    def forward(self, x):
        self.stats["tensor_val"].append(x)
        return x


def _convert_tuple_to_list(t: Any) -> Any:
    return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t


def _dequantize_tensor_list(t: Any) -> Any:
    return (
        list(_dequantize_tensor_list(x) for x in t)
        if type(t) is list
        else t.dequantize()
        if t.is_quantized
        else t
    )


class Shadow(nn.Module):
    r"""Shadow module attaches the float module to its matching quantized module
    as the shadow. Then it uses Logger module to process the outputs of both
    modules.

    Args:
        q_module: module quantized from float_module that we want to shadow
        float_module: float module used to shadow q_module
        logger_cls: type of logger used to process the outputs of q_module and
            float_module. ShadowLogger or custom loggers can be used.
    """

    def __init__(self, q_module, float_module, logger_cls):
        super(Shadow, self).__init__()
        self.orig_module = q_module
        self.shadow_module = float_module
        self.dequant = nnq.DeQuantize()
        self.logger = logger_cls()

    def forward(self, *x) -> torch.Tensor:
        xl = _convert_tuple_to_list(x)
        output = self.orig_module(*xl)
        xl_float = _dequantize_tensor_list(xl)
        shadow_output = self.shadow_module(*xl_float)
        self.logger(output, shadow_output)
        return output

    def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = self.orig_module.add(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.add(x, y)
        self.logger(output, shadow_output)
        return output

    def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        output = self.orig_module.add_scalar(x, y)
        x = x.dequantize()
        shadow_output = self.shadow_module.add_scalar(x, y)
        self.logger(output, shadow_output)
        return output

    def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = self.orig_module.mul(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.mul(x, y)
        self.logger(output, shadow_output)
        return output

    def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        output = self.orig_module.mul_scalar(x, y)
        x = x.dequantize()
        shadow_output = self.shadow_module.mul_scalar(x, y)
        self.logger(output, shadow_output)
        return output

    def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        output = self.orig_module.cat(x, dim)
        x = [y.dequantize() for y in x]
        shadow_output = self.shadow_module.cat(x, dim)
        self.logger(output, shadow_output)
        return output

    def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = self.orig_module.add_relu(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.add_relu(x, y)
        self.logger(output, shadow_output)
        return output


def prepare_model_with_stubs(
    float_module: nn.Module, q_module: nn.Module,
    module_swap_list: Set[type], logger_cls: Callable,
) -> None:
    r"""Prepare the model by attaching the float module to its matching quantized
    module as the shadow if the float module type is in module_swap_list.

    Example usage:
        prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
        q_model(data)
        ob_dict = get_logger_dict(q_model)

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        module_swap_list: list of float module types to attach the shadow
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")

    float_module_children = {}
    for name, mod in float_module.named_children():
        float_module_children[name] = mod

    reassign = {}
    for name, mod in q_module.named_children():

        if name not in float_module_children:
            continue

        float_mod = float_module_children[name]

        if type(float_mod) not in module_swap_list:
            prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls)

        # Insert shadow module only if the module is not of the same type as
        # the floating point module
        if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod):
            reassign[name] = Shadow(mod, float_mod, logger_cls)

    for key, value in reassign.items():
        q_module._modules[key] = value

def _is_identical_module_type(mod1, mod2):
    # Compare if two modules have the same dtype
    mod1_module_types = [type(mod) for mod in mod1.modules()]
    mod2_module_types = [type(mod) for mod in mod2.modules()]
    return mod1_module_types == mod2_module_types



def compare_model_stub(
    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
    *data, logger_cls=ShadowLogger
) -> Dict[str, Dict]:
    r"""Compare quantized module in a model with its floating point counterpart,
    feeding both of them the same input. Return a dict with key corresponding to
    module names and each entry being a dictionary with two keys 'float' and
    'quantized', containing the output tensors of quantized and its matching
    float shadow module. This dict can be used to compare and compute the module
    level quantization error.

    This function first call prepare_model_with_stubs() to swap the quantized
    module that we want to compare with the Shadow module, which takes quantized
    module, corresponding float module and logger as input, and creates a forward
    path inside to make the float module to shadow quantized module sharing the
    same input. The logger can be customizable, default logger is ShadowLogger
    and it will save the outputs of the quantized module and float module that
    can be used to compute the module level quantization error.

    Example usage:
        module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
        ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
        for key in ob_dict:
            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        module_swap_list: list of float module types at which shadow modules will
            be attached.
        data: input data used to run the prepared q_model
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
    prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
    q_model(*data)
    ob_dict = get_logger_dict(q_model)
    return ob_dict


def get_matching_activations(
    float_module: nn.Module, q_module: nn.Module,
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Find the matching activation between float and quantized modules.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module

    Return:
        act_dict: dict with key corresponding to quantized module names and each
        entry being a dictionary with two keys 'float' and 'quantized', containing
        the matching float and quantized activations
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
    float_dict = get_logger_dict(float_module)
    quantized_dict = get_logger_dict(q_module)
    act_dict: Dict[str, Dict] = {}
    for key in quantized_dict:
        match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
        if match_key is not None:
            act_dict[key] = {}
            act_dict[key]["float"] = float_dict[match_key]["tensor_val"]
            act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"]
    return act_dict


def prepare_model_outputs(
    float_module: nn.Module,
    q_module: nn.Module,
    logger_cls=OutputLogger,
    allow_list=None
) -> None:
    r"""Prepare the model by attaching the logger to both float module
    and quantized module if they are in the allow_list.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
    if allow_list is None:
        allow_list = get_default_compare_output_module_list()

    qconfig_debug = torch.quantization.QConfig(activation=logger_cls, weight=None)
    float_module.qconfig = qconfig_debug  # type: ignore[assignment]
    prepare(float_module, inplace=True, allow_list=allow_list)
    q_module.qconfig = qconfig_debug  # type: ignore[assignment]
    prepare(
        q_module,
        inplace=True,
        allow_list=allow_list,
        observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
    )


def compare_model_outputs(
    float_model: nn.Module,
    q_model: nn.Module,
    *data,
    logger_cls=OutputLogger,
    allow_list=None
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare output activations between float and quantized models at
    corresponding locations for the same input. Return a dict with key corresponding
    to quantized module names and each entry being a dictionary with two keys
    'float' and 'quantized', containing the activations of quantized model and
    float model at matching locations. This dict can be used to compare and
    compute the propagation quantization error.

    Example usage:
        act_compare_dict = compare_model_outputs(float_model, qmodel, data)
        for key in act_compare_dict:
            print(key, compute_error(act_compare_dict[key]['float'], act_compare_dict[key]['quantized'].dequantize()))

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        data: input data used to run the prepared float_model and q_model
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger

    Return:
        act_compare_dict: dict with key corresponding to quantized module names
        and each entry being a dictionary with two keys 'float' and 'quantized',
        containing the matching float and quantized activations
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
    if allow_list is None:
        allow_list = get_default_compare_output_module_list()
    prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
    float_model(*data)
    q_model(*data)
    act_compare_dict = get_matching_activations(float_model, q_model)
    return act_compare_dict
```
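For orientation, here is a minimal end-to-end sketch of how the eager-mode APIs in this file fit together. The toy module, the calibration data, and the error metric are illustrative assumptions, not part of the commit:

```python
import copy
import torch
import torch.nn as nn
import torch.ao.ns._numeric_suite as ns

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(4, 4)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

float_model = M().eval()

# standard eager-mode post-training quantization
qmodel = copy.deepcopy(float_model)
qmodel.qconfig = torch.quantization.default_qconfig
torch.quantization.prepare(qmodel, inplace=True)
qmodel(torch.randn(8, 4))                        # calibrate observers
torch.quantization.convert(qmodel, inplace=True)

# 1. weight comparison, keyed by quantized parameter names
wt_cmp = ns.compare_weights(float_model.state_dict(), qmodel.state_dict())
for key, entry in wt_cmp.items():
    err = torch.norm(entry['float'] - entry['quantized'].dequantize())
    print(key, err.item())

# 2. activation comparison at matching module outputs
act_cmp = ns.compare_model_outputs(float_model, qmodel, torch.randn(8, 4))
print(list(act_cmp.keys()))
```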
torch/ao/ns/_numeric_suite_fx.py (new file, 513 lines):

```python
import collections

import torch
import torch.nn as nn
import torch.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.ns.fx.mappings import (
    get_base_name_to_sets_of_related_ops,
)
from torch.ao.ns.fx.graph_matcher import (
    get_matching_subgraph_pairs,
    get_type_a_related_to_b,
)

from .fx.weight_utils import (
    extract_weight_from_node,
)

from .fx.graph_passes import (
    add_loggers_to_model,
    create_a_shadows_b,
)

from .fx.utils import (
    rekey_logger_info_on_node_name_of_model,
    maybe_add_missing_fqns,
    get_target_type_str,
)

from .fx.ns_types import (
    NSSingleResultValuesType,
    NSResultsType,
    NSNodeTargetType,
)

from typing import Dict, Tuple, Callable, List, Optional, Set

RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

class OutputLogger(nn.Module):
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        ref_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
        fqn: Optional[str],
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # name of the node which was responsible for adding this logger
        # Note:
        # - if we are logging node outputs, this is the same as prev_node_name
        # - if we are logging node inputs, this is the name of the node
        #   whose input this logger is logging.
        #
        # example, where logger1 is logging input of op1 and logger2 is logging
        # the output of op1:
        #
        #  x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # in this example,
        #   - logger1's prev_node_name is x1 and ref_node_name is op1
        #   - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated from
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # type of the target of the node which was responsible for adding this
        # logger
        self.ref_node_target_type = ref_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node
        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
        self.index_within_arg = index_within_arg
        # index of this node within the args of the input/output node
        # for example, in add(x1, x2), x2 would have index_of_arg == 1
        self.index_of_arg = index_of_arg
        # fully qualified name
        self.fqn = fqn

    # Note: cannot annotate the type of x because TorchScript does not support
    # the Union type.
    def forward(self, x):
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        return x

    def __repr__(self):
        return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
ref_node_target_type={self.ref_node_target_type}
results_type={self.results_type}, index_within_arg={self.index_within_arg},
index_of_arg={self.index_of_arg}, fqn={self.fqn})"""


class NSTracer(quantize_fx.QuantizationTracer):
    """
    Just like a regular tracer, but treats observers and fake_quantize
    modules as leaf modules.
    """
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
        if isinstance(m, torch.quantization.ObserverBase):
            return True
        elif isinstance(m, torch.quantization.FakeQuantizeBase):
            return True
        return super().is_leaf_module(m, module_qualified_name)


def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> None:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
    for node, ref_name in nodes_and_names_to_instrument:
        res_type = NSSingleResultValuesType.WEIGHT.value
        extracted_weight = extract_weight_from_node(
            node, model, op_to_type_to_weight_extraction_fn)
        if extracted_weight:
            if ref_name not in results:
                results[ref_name] = {res_type: {}}
            results[ref_name][res_type][model_name] = [extracted_weight]


def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)

    # split the subgraph pairs into one data structure for each model
    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
        nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))

    # populate the results, one model at a time
    results: NSResultsType = {}
    _extract_weights_one_model(
        model_name_a, gm_a, nodes_and_names_to_instrument_a, results,
        op_to_type_to_weight_extraction_fn)
    _extract_weights_one_model(
        model_name_b, gm_b, nodes_and_names_to_instrument_b, results,
        op_to_type_to_weight_extraction_fn)

    # fill in missing fqn entries
    maybe_add_missing_fqns(results)

    # rekey on names of nodes in gm_b
    results = rekey_logger_info_on_node_name_of_model(results, model_name_b)

    return results


def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = \
            get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _extract_weights_impl(
        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map, op_to_type_to_weight_extraction_fn)


def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
    logger_cls: Callable,
) -> nn.Module:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")

    # TODO(future PR): do not observe nodes we do not care
    #   about (both fp32, denylist, etc)
    node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
        node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
        node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)

    model = add_loggers_to_model(
        model, node_to_instrument_inputs_to_ref_name,
        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
    return model


def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b,
        base_name_to_sets_of_related_ops, unmatchable_types_map)
    nodes_and_names_to_instrument_inputs_a = []
    nodes_and_names_to_instrument_inputs_b = []
    nodes_and_names_to_instrument_outputs_a = []
    nodes_and_names_to_instrument_outputs_b = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        # Note: for matching inputs we use start_node, such as observing
        # the input of linear in linear-relu
        if should_log_inputs:
            nodes_and_names_to_instrument_inputs_a.append(
                (subgraph_a.start_node, match_name, ref_node_type_a))
            nodes_and_names_to_instrument_inputs_b.append(
                (subgraph_b.start_node, match_name, ref_node_type_b))
        # Note: for matching activations we always use end_node,
        # such as observing the output of relu in linear-relu
        nodes_and_names_to_instrument_outputs_a.append(
            (subgraph_a.end_node, match_name, ref_node_type_a))
        nodes_and_names_to_instrument_outputs_b.append(
            (subgraph_b.end_node, match_name, ref_node_type_b))

    new_model_a = _add_loggers_one_model(
        name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
        nodes_and_names_to_instrument_outputs_a, logger_cls)
    new_model_b = _add_loggers_one_model(
        name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
        nodes_and_names_to_instrument_outputs_b, logger_cls)
    return (new_model_a, new_model_b)


def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs : bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map)


def _extract_logger_info_one_model(
    model: nn.Module,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
    for gm_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore[arg-type]
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if is_logger:
            key = mod.ref_name
            if key not in results:
                results[key] = {}
            assert mod.model_name not in results[key], \
                f"{mod.model_name} is already present in results"
            if mod.results_type not in results[key]:
                results[key][mod.results_type] = {}
            if mod.model_name not in results[key][mod.results_type]:
                results[key][mod.results_type][mod.model_name] = []
            stats_to_use = mod.stats
            if len(mod.stats_rnn) > 0:
                stats_to_use = mod.stats_rnn
            results[key][mod.results_type][mod.model_name].append({
                'type': mod.results_type,
                'values': stats_to_use,
                'ref_node_name': mod.ref_node_name,
                'ref_node_target_type': mod.ref_node_target_type,
                'prev_node_name': mod.prev_node_name,
                'prev_node_target_type': mod.prev_node_target_type,
                'index_within_arg': mod.index_within_arg,
                'index_of_arg': mod.index_of_arg,
                'fqn': mod.fqn,
            })
            # ensure the list stays sorted
            results[key][mod.results_type][mod.model_name].sort(
                key=lambda res:
                f"{res['index_of_arg']}:{res['index_within_arg']}"
            )


# TODO(future PR): align on naming
# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as ns.extract_logger_info, but for models prepared with
    this module.

    TODO(future PR): real docblock

    Output format: NSResultsType
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
    results: NSResultsType = {}
    for model in (model_a, model_b):
        _extract_logger_info_one_model(model, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return results


def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
    gm_a_shadows_b = create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map)
    return gm_a_shadows_b


def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Same thing as add_loggers, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map)


def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as extract_logger_info, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
    results: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return dict(results)


def extend_logger_results_with_comparison(
    results: NSResultsType,
    model_name_1: str,
    model_name_2: str,
    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    comparison_name: str,
) -> None:
    """
    Compares the logged values from `model_name_2` against the corresponding
    values in `model_name_1`, using `comparison_fn`. Records the result
    in `model_name_2`'s results under `comparison_name`.
    """
    for _, results_type_to_results in results.items():
        for _, model_name_to_results in results_type_to_results.items():
            assert model_name_1 in model_name_to_results, \
                f"{model_name_1} not found in results"
            assert model_name_2 in model_name_to_results, \
                f"{model_name_2} not found in results"

            results_1 = model_name_to_results[model_name_1]
            results_2 = model_name_to_results[model_name_2]

            for result_2 in results_2:
                index_within_arg_2 = result_2['index_within_arg']
                index_of_arg_2 = result_2['index_of_arg']
                # find corresponding result_1
                result_1 = None
                for cur_result_1 in results_1:
                    index_within_arg_1 = cur_result_1['index_within_arg']
                    index_of_arg_1 = cur_result_1['index_of_arg']
                    if (
                        (index_within_arg_1 == index_within_arg_2) and
                        (index_of_arg_1 == index_of_arg_2)
                    ):
                        result_1 = cur_result_1
                        break
                assert result_1 is not None

                values_1 = result_1['values']
                values_2 = result_2['values']
                result_2[comparison_name] = []
                for value_1, value_2 in zip(values_1, values_2):
                    comparison_result = comparison_fn(value_1, value_2)
                    result_2[comparison_name].append(comparison_result)
```
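Similarly, a hypothetical sketch of the FX numeric suite workflow defined above, comparing a float model against its FX-quantized counterpart. The toy model, the model names `'fp32'`/`'int8'`, and the calibration data are assumptions for illustration:

```python
import copy
import torch
import torch.nn as nn
import torch.quantization.quantize_fx as quantize_fx
import torch.ao.ns._numeric_suite_fx as nsfx
from torch.ao.ns.fx.utils import compute_sqnr

m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU()).eval()
mp = quantize_fx.prepare_fx(copy.deepcopy(m), {'': torch.quantization.default_qconfig})
mp(torch.randn(1, 1, 4, 4))            # calibrate observers
mq = quantize_fx.convert_fx(copy.deepcopy(mp))

# 1. compare weights of matched subgraphs, then attach SQNR to the results
results = nsfx.extract_weights('fp32', m, 'int8', mq)
nsfx.extend_logger_results_with_comparison(
    results, 'fp32', 'int8', compute_sqnr, 'sqnr')

# 2. compare activations: instrument both models, run data, collect stats
m_ns, mq_ns = nsfx.add_loggers('fp32', copy.deepcopy(m), 'int8', mq, nsfx.OutputLogger)
datum = torch.randn(1, 1, 4, 4)
m_ns(datum)
mq_ns(datum)
act_results = nsfx.extract_logger_info(m_ns, mq_ns, nsfx.OutputLogger, 'int8')
```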
torch/ao/ns/fx/__init__.py (new empty file)
```diff
@@ -19,7 +19,7 @@ from .ns_types import (
     NSSubgraph,
     NSNodeTargetType,
 )
-from torch.quantization.ns.mappings import (
+from torch.ao.ns.fx.mappings import (
     get_node_type_to_io_type_map,
 )
 from torch.quantization.quantize import is_activation_post_process
```
```diff
@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.nn.quantized as nnq

 import torch.quantization
-import torch.quantization._numeric_suite as ns
+import torch.ao.ns._numeric_suite as ns

 _supported_modules = {nn.Linear, nn.Conv2d}
 _supported_modules_quantized = {nnq.Linear, nnq.Conv2d}
```
torch/quantization/_numeric_suite.py (rewritten: 486 lines replaced by a 28-line compatibility shim)

```diff
@@ -1,486 +1,28 @@
-import torch
-import torch.nn as nn
-import torch.nn.quantized as nnq
-import torch.nn.quantized.dynamic as nnqd
-from torch.quantization import prepare
-from typing import Dict, List, Optional, Any, Union, Callable, Set
+# flake8: noqa: F401
+r"""
+This file is in the process of migration to `torch/ao/quantization`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please, add it to the
+`torch/ao/ns/_numeric_suite.py`, while adding an import statement
+here.
+"""

-from .quantization_mappings import (
-    get_default_compare_output_module_list,
+from torch.ao.ns._numeric_suite import (
+    NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
+    _find_match,
+    compare_weights,
+    _get_logger_dict_helper,
+    get_logger_dict,
+    Logger,
+    ShadowLogger,
+    OutputLogger,
+    _convert_tuple_to_list,
+    _dequantize_tensor_list,
+    Shadow,
+    prepare_model_with_stubs,
+    _is_identical_module_type,
+    compare_model_stub,
+    get_matching_activations,
+    prepare_model_outputs,
+    compare_model_outputs,
 )
```
NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
|
||||
nnqd.Linear,
|
||||
nnq.Linear,
|
||||
nnqd.LSTM,
|
||||
nn.LSTM,
|
||||
}
|
||||
|
||||
|
||||
def _find_match(
|
||||
str_list: Union[Dict[str, Any], List[str]], key_str: str,
|
||||
postfix: str,
|
||||
) -> Optional[str]:
|
||||
split_str = key_str.split(".")
|
||||
if split_str[-1] == postfix:
|
||||
match_string = "".join(key_str.split(".")[0:-1])
|
||||
for s2 in str_list:
|
||||
pattern1 = "".join(s2.split(".")[0:-1])
|
||||
pattern2 = "".join(s2.split(".")[0:-2])
|
||||
if match_string == pattern1:
|
||||
return s2
|
||||
if match_string == pattern2:
|
||||
return s2
|
||||
|
||||
# For matching "fc.weight" and "fc._packed_params._packed_params"
|
||||
if postfix == "_packed_params":
|
||||
match_string = "".join(key_str.split(".")[0:-2])
|
||||
if len(match_string) == 0:
|
||||
return None
|
||||
for s2 in str_list:
|
||||
pattern1 = "".join(s2.split(".")[0:-1])
|
||||
pattern2 = "".join(s2.split(".")[0:-2])
|
||||
if match_string == pattern1:
|
||||
return s2
|
||||
if match_string == pattern2:
|
||||
return s2
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def compare_weights(
|
||||
float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
|
||||
) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
r"""Compare the weights of the float module with its corresponding quantized
|
||||
module. Return a dict with key corresponding to module names and each entry being
|
||||
a dictionary with two keys 'float' and 'quantized', containing the float and
|
||||
quantized weights. This dict can be used to compare and compute the quantization
|
||||
error of the weights of float and quantized models.
|
||||
|
||||
Example usage:
|
||||
wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())
|
||||
for key in wt_compare_dict:
|
||||
print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))
|
||||
|
||||
Args:
|
||||
float_dict: state dict of the float model
|
||||
quantized_dict: state dict of the quantized model
|
||||
|
||||
Return:
|
||||
weight_dict: dict with key corresponding to module names and each entry being
|
||||
a dictionary with two keys 'float' and 'quantized', containing the float and
|
||||
quantized weights
|
||||
"""
|
||||
torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
|
||||
weight_dict: Dict[str, Dict] = {}
|
||||
for key in quantized_dict:
|
||||
match_key = _find_match(float_dict, key, "weight")
|
||||
if match_key is not None:
|
||||
weight_dict[key] = {}
|
||||
weight_dict[key]["float"] = float_dict[match_key]
|
||||
weight_dict[key]["quantized"] = quantized_dict[key]
|
||||
continue
|
||||
|
||||
# For matching "fc.weight" and "fc._packed_params._packed_params"
|
||||
match_key = _find_match(float_dict, key, "_packed_params")
|
||||
if match_key is not None:
|
||||
weight_dict[key] = {}
|
||||
weight_dict[key]["float"] = float_dict[match_key]
|
||||
weight_dict[key]["quantized"] = quantized_dict[key][0]
|
||||
|
||||
# For LSTM
|
||||
split_str = key.split(".")
|
||||
if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
|
||||
layer = split_str[-2]
|
||||
module_name = ".".join(split_str[:-3])
|
||||
float_weight_ih_key = module_name + ".weight_ih_l" + layer
|
||||
float_weight_hh_key = module_name + ".weight_hh_l" + layer
|
||||
if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
|
||||
weight_dict[key] = {}
|
||||
weight_dict[key]["float"] = float_dict[float_weight_ih_key]
|
||||
weight_dict[key]["quantized"] = (
|
||||
quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
|
||||
)
|
||||
weight_dict[key]["float"] = float_dict[float_weight_hh_key]
|
||||
weight_dict[key]["quantized"] = (
|
||||
quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
|
||||
)
|
||||
|
||||
return weight_dict
|
||||
|
||||
|
||||
def _get_logger_dict_helper(
|
||||
mod: nn.Module, target_dict: Dict[str, Any],
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
r"""This is the helper function for get_logger_dict
|
||||
|
||||
Args:
|
||||
mod: module we want to save all logger stats
|
||||
prefix: prefix for the current module
|
||||
target_dict: the dictionary used to save all logger stats
|
||||
"""
|
||||
|
||||
def get_prefix(prefix):
|
||||
return prefix if prefix == "" else prefix + "."
|
||||
|
||||
for name, child in mod.named_children():
|
||||
if isinstance(child, Logger):
|
||||
target_dict[get_prefix(prefix) + "stats"] = child.stats
|
||||
break
|
||||
|
||||
for name, child in mod.named_children():
|
||||
module_prefix = get_prefix(prefix) + name if prefix else name
|
||||
_get_logger_dict_helper(child, target_dict, module_prefix)
|
||||
|
||||
|
||||
def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
|
||||
r"""Traverse the modules and save all logger stats into target dict.
|
||||
This is mainly used for quantization accuracy debug.
|
||||
|
||||
Type of loggers supported:
|
||||
ShadowLogger: used to log the outputs of the quantized module and its
|
||||
matching float shadow module,
|
||||
OutputLogger: used to log the outputs of the modules
|
||||
|
||||
Args:
|
||||
mod: module we want to save all logger stats
|
||||
prefix: prefix for the current module
|
||||
|
||||
Return:
|
||||
target_dict: the dictionary used to save all logger stats
|
||||
"""
|
||||
torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")
|
||||
|
||||
target_dict: Dict[str, Dict] = {}
|
||||
_get_logger_dict_helper(mod, target_dict, prefix)
|
||||
return target_dict
|
||||
|
||||
|
||||
class Logger(nn.Module):
|
||||
r"""Base class for stats logging
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Logger, self).__init__()
|
||||
self.stats = {}
|
||||
# We only insert observer if the op is quantized with static quantization,
|
||||
# which is identified by activation_observer.dtype == quint8. This is needed
|
||||
# when attaching Logger as observer for FX mode
|
||||
self.dtype = torch.quint8
|
||||
|
||||
def forward(self, x):
|
||||
pass
|
||||
|
||||
|
||||
class ShadowLogger(Logger):
|
||||
r"""Class used in Shadow module to record the outputs of the original and
|
||||
shadow modules.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(ShadowLogger, self).__init__()
|
||||
self.stats["float"] = []
|
||||
self.stats["quantized"] = []
|
||||
|
||||
def forward(self, x, y):
|
||||
if len(x) > 1:
|
||||
x = x[0]
|
||||
if len(y) > 1:
|
||||
y = y[0]
|
||||
self.stats["quantized"].append(x.detach())
|
||||
self.stats["float"].append(y.detach())
|
||||
|
||||
|
||||
class OutputLogger(Logger):
|
||||
r"""Class used to log the outputs of the module
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(OutputLogger, self).__init__()
|
||||
self.stats["tensor_val"] = []
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
self.stats["tensor_val"].append(x)
|
||||
return x
|
||||
|
||||
|
||||
def _convert_tuple_to_list(t: Any) -> Any:
|
||||
return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t
|
||||
|
||||
|
||||
def _dequantize_tensor_list(t: Any) -> Any:
|
||||
return (
|
||||
list(_dequantize_tensor_list(x) for x in t)
|
||||
if type(t) is list
|
||||
else t.dequantize()
|
||||
if t.is_quantized
|
||||
else t
|
||||
)
|
||||
|
||||
|
||||
class Shadow(nn.Module):
|
||||
r"""Shadow module attaches the float module to its matching quantized module
|
||||
as the shadow. Then it uses Logger module to process the outputs of both
|
||||
modules.
|
||||
|
||||
Args:
|
||||
q_module: module quantized from float_module that we want to shadow
|
||||
float_module: float module used to shadow q_module
|
||||
logger_cls: type of logger used to process the outputs of q_module and
|
||||
float_module. ShadowLogger or custom loggers can be used.
|
||||
"""
|
||||
|
||||
def __init__(self, q_module, float_module, logger_cls):
|
||||
super(Shadow, self).__init__()
|
||||
self.orig_module = q_module
|
||||
self.shadow_module = float_module
|
||||
self.dequant = nnq.DeQuantize()
|
||||
self.logger = logger_cls()
|
||||
|
||||
def forward(self, *x) -> torch.Tensor:
|
||||
xl = _convert_tuple_to_list(x)
|
||||
output = self.orig_module(*xl)
|
||||
xl_float = _dequantize_tensor_list(xl)
|
||||
shadow_output = self.shadow_module(*xl_float)
|
||||
self.logger(output, shadow_output)
|
||||
return output
|
||||
|
||||
def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
output = self.orig_module.add(x, y)
|
||||
x = x.dequantize()
|
||||
y = y.dequantize()
|
||||
shadow_output = self.shadow_module.add(x, y)
|
||||
self.logger(output, shadow_output)
|
||||
return output
|
||||
|
||||
def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
|
||||
output = self.orig_module.add_scalar(x, y)
|
||||
x = x.dequantize()
|
||||
shadow_output = self.shadow_module.add_scalar(x, y)
|
||||
self.logger(output, shadow_output)
|
||||
return output
|
||||
|
||||
def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
output = self.orig_module.mul(x, y)
|
||||
x = x.dequantize()
|
||||
y = y.dequantize()
|
||||
shadow_output = self.shadow_module.mul(x, y)
|
||||
self.logger(output, shadow_output)
|
||||
return output
|
||||
|
||||
def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
|
||||
output = self.orig_module.mul_scalar(x, y)
|
||||
x = x.dequantize()
|
||||
shadow_output = self.shadow_module.mul_scalar(x, y)
|
||||
self.logger(output, shadow_output)
|
||||
return output
|
||||
|
||||
def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
|
||||
output = self.orig_module.cat(x, dim)
|
||||
x = [y.dequantize() for y in x]
|
||||
shadow_output = self.shadow_module.cat(x, dim)
|
||||
self.logger(output, shadow_output)
|
||||
return output
|
||||
|
||||
def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
output = self.orig_module.add_relu(x, y)
|
||||
x = x.dequantize()
|
||||
y = y.dequantize()
|
||||
shadow_output = self.shadow_module.add_relu(x, y)
|
||||
self.logger(output, shadow_output)
|
||||
return output
|
||||
|
||||
|
||||
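

# Editor's note -- a minimal usage sketch, not part of the library. `q_conv` is
# assumed to be a quantized module and `float_conv` its float counterpart; the
# Shadow wrapper runs both on the same input (dequantized for the float side)
# and routes the two outputs through the logger:
#
#     shadow = Shadow(q_conv, float_conv, ShadowLogger)
#     out = shadow(q_input)            # q_input: a quantized tensor
#     stats = shadow.logger.stats      # {"float": [...], "quantized": [...]}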


def prepare_model_with_stubs(
    float_module: nn.Module, q_module: nn.Module,
    module_swap_list: Set[type], logger_cls: Callable,
) -> None:
    r"""Prepare the model by attaching the float module to its matching quantized
    module as the shadow if the float module type is in module_swap_list.

    Example usage:
        prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
        q_model(data)
        ob_dict = get_logger_dict(q_model)

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        module_swap_list: list of float module types to attach the shadow
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")

    float_module_children = {}
    for name, mod in float_module.named_children():
        float_module_children[name] = mod

    reassign = {}
    for name, mod in q_module.named_children():
        if name not in float_module_children:
            continue

        float_mod = float_module_children[name]

        if type(float_mod) not in module_swap_list:
            prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls)

        # Insert shadow module only if the module is not of the same type as
        # the floating point module
        if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod):
            reassign[name] = Shadow(mod, float_mod, logger_cls)

    for key, value in reassign.items():
        q_module._modules[key] = value


def _is_identical_module_type(mod1, mod2):
    # Compare whether the two modules contain the same sequence of module types
    mod1_module_types = [type(mod) for mod in mod1.modules()]
    mod2_module_types = [type(mod) for mod in mod2.modules()]
    return mod1_module_types == mod2_module_types


def compare_model_stub(
    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
    *data, logger_cls=ShadowLogger
) -> Dict[str, Dict]:
    r"""Compare quantized module in a model with its floating point counterpart,
    feeding both of them the same input. Return a dict with key corresponding to
    module names and each entry being a dictionary with two keys 'float' and
    'quantized', containing the output tensors of quantized and its matching
    float shadow module. This dict can be used to compare and compute the module
    level quantization error.

    This function first calls prepare_model_with_stubs() to swap the quantized
    modules that we want to compare with Shadow modules, which take a quantized
    module, the corresponding float module and a logger as input, and create a
    forward path inside to make the float module shadow the quantized module
    while sharing the same input. The logger is customizable; the default logger
    is ShadowLogger, which saves the outputs of the quantized module and float
    module so that they can be used to compute the module level quantization error.

    Example usage:
        module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
        ob_dict = compare_model_stub(float_model, q_model, module_swap_list, data)
        for key in ob_dict:
            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        module_swap_list: list of float module types at which shadow modules will
            be attached.
        data: input data used to run the prepared q_model
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
    prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
    q_model(*data)
    ob_dict = get_logger_dict(q_model)
    return ob_dict


def get_matching_activations(
    float_module: nn.Module, q_module: nn.Module,
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Find the matching activation between float and quantized modules.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module

    Return:
        act_dict: dict with key corresponding to quantized module names and each
        entry being a dictionary with two keys 'float' and 'quantized', containing
        the matching float and quantized activations
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
    float_dict = get_logger_dict(float_module)
    quantized_dict = get_logger_dict(q_module)
    act_dict: Dict[str, Dict] = {}
    for key in quantized_dict:
        match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
        if match_key is not None:
            act_dict[key] = {}
            act_dict[key]["float"] = float_dict[match_key]["tensor_val"]
            act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"]
    return act_dict


def prepare_model_outputs(
    float_module: nn.Module,
    q_module: nn.Module,
    logger_cls=OutputLogger,
    allow_list=None
) -> None:
    r"""Prepare the model by attaching the logger to both float module
    and quantized module if they are in the allow_list.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
    if allow_list is None:
        allow_list = get_default_compare_output_module_list()

    qconfig_debug = torch.quantization.QConfig(activation=logger_cls, weight=None)
    float_module.qconfig = qconfig_debug  # type: ignore[assignment]
    prepare(float_module, inplace=True, allow_list=allow_list)
    q_module.qconfig = qconfig_debug  # type: ignore[assignment]
    prepare(
        q_module,
        inplace=True,
        allow_list=allow_list,
        observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
    )
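

# Editor's note -- a minimal end-to-end sketch, not part of the library, showing
# what compare_model_outputs() below automates. `float_model`, `q_model` and
# `data` are assumed to exist: prepare both models, run the same calibration
# data through each, then collect the matched activations.
#
#     prepare_model_outputs(float_model, q_model)
#     float_model(data)
#     q_model(data)
#     act_dict = get_matching_activations(float_model, q_model)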


def compare_model_outputs(
    float_model: nn.Module,
    q_model: nn.Module,
    *data,
    logger_cls=OutputLogger,
    allow_list=None
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare output activations between float and quantized models at
    corresponding locations for the same input. Return a dict with key corresponding
    to quantized module names and each entry being a dictionary with two keys
    'float' and 'quantized', containing the activations of quantized model and
    float model at matching locations. This dict can be used to compare and
    compute the propagation quantization error.

    Example usage:
        act_compare_dict = compare_model_outputs(float_model, q_model, data)
        for key in act_compare_dict:
            print(key, compute_error(act_compare_dict[key]['float'], act_compare_dict[key]['quantized'].dequantize()))

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        data: input data used to run the prepared float_model and q_model
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger

    Return:
        act_compare_dict: dict with key corresponding to quantized module names
        and each entry being a dictionary with two keys 'float' and 'quantized',
        containing the matching float and quantized activations
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
    if allow_list is None:
        allow_list = get_default_compare_output_module_list()
    prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
    float_model(*data)
    q_model(*data)
    act_compare_dict = get_matching_activations(float_model, q_model)
    return act_compare_dict

@@ -1,513 +1,26 @@
import collections
# flake8: noqa: F401
r"""
This file is in the process of migration to `torch/ao/quantization`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
`torch/ao/ns/_numeric_suite_fx.py`, while adding an import statement
here.
"""

import torch
import torch.nn as nn
import torch.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.quantization.ns.mappings import (
    get_base_name_to_sets_of_related_ops,
from torch.ao.ns._numeric_suite_fx import (
    RNNReturnType,
    OutputLogger,
    NSTracer,
    _extract_weights_one_model,
    _extract_weights_impl,
    extract_weights,
    _add_loggers_one_model,
    _add_loggers_impl,
    add_loggers,
    _extract_logger_info_one_model,
    extract_logger_info,
    _add_shadow_loggers_impl,
    add_shadow_loggers,
    extract_shadow_logger_info,
    extend_logger_results_with_comparison,
)
from torch.quantization.ns.graph_matcher import (
    get_matching_subgraph_pairs,
    get_type_a_related_to_b,
)

from .ns.weight_utils import (
    extract_weight_from_node,
)

from .ns.graph_passes import (
    add_loggers_to_model,
    create_a_shadows_b,
)

from .ns.utils import (
    rekey_logger_info_on_node_name_of_model,
    maybe_add_missing_fqns,
    get_target_type_str,
)

from .ns.ns_types import (
    NSSingleResultValuesType,
    NSResultsType,
    NSNodeTargetType,
)

from typing import Dict, Tuple, Callable, List, Optional, Set

RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


class OutputLogger(nn.Module):
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        ref_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
        fqn: Optional[str],
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # name of the node which was responsible for adding this logger
        # Note:
        # - if we are logging node outputs, this is the same as prev_node_name
        # - if we are logging node inputs, this is the name of the node
        #   whose input this logger is logging.
        #
        # example, where logger1 is logging input of op1 and logger2 is logging
        # the output of op1:
        #
        #   x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # in this example,
        # - logger1's prev_node_name is x1 and ref_node_name is op1
        # - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated from
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # type of the target of the node which was responsible for adding this
        # logger
        self.ref_node_target_type = ref_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node
        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
        self.index_within_arg = index_within_arg
        # index of this node within the args of the input/output node
        # for example, in add(x1, x2), x2 would have index_of_arg == 1
        self.index_of_arg = index_of_arg
        # fully qualified name
        self.fqn = fqn

    # Note: cannot annotate the type of x because TorchScript does not support
    # the Union type.
    def forward(self, x):
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        return x

    def __repr__(self):
        return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
ref_node_target_type={self.ref_node_target_type}
results_type={self.results_type}, index_within_arg={self.index_within_arg},
index_of_arg={self.index_of_arg}, fqn={self.fqn})"""


class NSTracer(quantize_fx.QuantizationTracer):
    """
    Just like a regular tracer, but treats observers and fake_quantize
    modules as leaf modules.
    """
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        if isinstance(m, torch.quantization.ObserverBase):
            return True
        elif isinstance(m, torch.quantization.FakeQuantizeBase):
            return True
        return super().is_leaf_module(m, module_qualified_name)


def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> None:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
    for node, ref_name in nodes_and_names_to_instrument:
        res_type = NSSingleResultValuesType.WEIGHT.value
        extracted_weight = extract_weight_from_node(
            node, model, op_to_type_to_weight_extraction_fn)
        if extracted_weight:
            if ref_name not in results:
                results[ref_name] = {res_type: {}}
            results[ref_name][res_type][model_name] = [extracted_weight]


def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)

    # split the subgraph pairs into one data structure for each model
    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
        nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))

    # populate the results, one model at a time
    results: NSResultsType = {}
    _extract_weights_one_model(
        model_name_a, gm_a, nodes_and_names_to_instrument_a, results,
        op_to_type_to_weight_extraction_fn)
    _extract_weights_one_model(
        model_name_b, gm_b, nodes_and_names_to_instrument_b, results,
        op_to_type_to_weight_extraction_fn)

    # fill in missing fqn entries
    maybe_add_missing_fqns(results)

    # rekey on names of nodes in gm_b
    results = rekey_logger_info_on_node_name_of_model(results, model_name_b)

    return results


def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = \
            get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _extract_weights_impl(
        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map, op_to_type_to_weight_extraction_fn)
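

# Editor's note -- a minimal usage sketch, not part of the library. `m_fp32`
# and `m_int8` are assumed to be a traceable float model and its quantized
# counterpart; the result maps layer names to per-result-type, per-model lists
# of extracted weights:
#
#     results = extract_weights('fp32', m_fp32, 'int8', m_int8)
#     # e.g. results[layer_name][result_type]['int8'] -> [weight_tensor]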


def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
    logger_cls: Callable,
) -> nn.Module:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")

    # TODO(future PR): do not observe nodes we do not care
    #   about (both fp32, denylist, etc)
    node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
        node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
        node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)

    model = add_loggers_to_model(
        model, node_to_instrument_inputs_to_ref_name,
        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
    return model


def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b,
        base_name_to_sets_of_related_ops, unmatchable_types_map)
    nodes_and_names_to_instrument_inputs_a = []
    nodes_and_names_to_instrument_inputs_b = []
    nodes_and_names_to_instrument_outputs_a = []
    nodes_and_names_to_instrument_outputs_b = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        # Note: for matching inputs we use start_node, such as observing
        # the input of linear in linear-relu
        if should_log_inputs:
            nodes_and_names_to_instrument_inputs_a.append(
                (subgraph_a.start_node, match_name, ref_node_type_a))
            nodes_and_names_to_instrument_inputs_b.append(
                (subgraph_b.start_node, match_name, ref_node_type_b))
        # Note: for matching activations we always use end_node,
        # such as observing the output of relu in linear-relu
        nodes_and_names_to_instrument_outputs_a.append(
            (subgraph_a.end_node, match_name, ref_node_type_a))
        nodes_and_names_to_instrument_outputs_b.append(
            (subgraph_b.end_node, match_name, ref_node_type_b))

    new_model_a = _add_loggers_one_model(
        name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
        nodes_and_names_to_instrument_outputs_a, logger_cls)
    new_model_b = _add_loggers_one_model(
        name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
        nodes_and_names_to_instrument_outputs_b, logger_cls)
    return (new_model_a, new_model_b)


def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map)
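

# Editor's note -- a minimal usage sketch, not part of the library. `m_fp32`,
# `m_int8` and `data` are assumed to exist: instrument both models, run the
# same calibration data through each, then pull the logged activations with
# extract_logger_info() below.
#
#     m_fp32_ns, m_int8_ns = add_loggers('fp32', m_fp32, 'int8', m_int8, OutputLogger)
#     m_fp32_ns(data)
#     m_int8_ns(data)
#     results = extract_logger_info(m_fp32_ns, m_int8_ns, OutputLogger, 'int8')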


def _extract_logger_info_one_model(
    model: nn.Module,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
    for gm_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore[arg-type]
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if is_logger:
            key = mod.ref_name
            if key not in results:
                results[key] = {}
            assert mod.model_name not in results[key], \
                f"{mod.model_name} is already present in results"
            if mod.results_type not in results[key]:
                results[key][mod.results_type] = {}
            if mod.model_name not in results[key][mod.results_type]:
                results[key][mod.results_type][mod.model_name] = []
            stats_to_use = mod.stats
            if len(mod.stats_rnn) > 0:
                stats_to_use = mod.stats_rnn
            results[key][mod.results_type][mod.model_name].append({
                'type': mod.results_type,
                'values': stats_to_use,
                'ref_node_name': mod.ref_node_name,
                'ref_node_target_type': mod.ref_node_target_type,
                'prev_node_name': mod.prev_node_name,
                'prev_node_target_type': mod.prev_node_target_type,
                'index_within_arg': mod.index_within_arg,
                'index_of_arg': mod.index_of_arg,
                'fqn': mod.fqn,
            })
            # ensure the list stays sorted
            results[key][mod.results_type][mod.model_name].sort(
                key=lambda res:
                f"{res['index_of_arg']}:{res['index_within_arg']}"
            )


# TODO(future PR): align on naming
# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as ns.extract_logger_info, but for models prepared with
    this module.

    TODO(future PR): real docblock

    Output format: NSResultsType
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
    results: NSResultsType = {}
    for model in (model_a, model_b):
        _extract_logger_info_one_model(model, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return results


def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
    gm_a_shadows_b = create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map)
    return gm_a_shadows_b


def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Same thing as add_loggers, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map)


def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as extract_logger_info, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
    results: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return dict(results)
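

# Editor's note -- a minimal usage sketch, not part of the library. `m_fp32`,
# `m_int8` and `data` are assumed to exist: build a single model in which the
# float model shadows the quantized one subgraph by subgraph, run calibration
# data through it once, then collect both sets of logged activations.
#
#     m_shadow = add_shadow_loggers('fp32', m_fp32, 'int8', m_int8, OutputLogger)
#     m_shadow(data)
#     results = extract_shadow_logger_info(m_shadow, OutputLogger, 'int8')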


def extend_logger_results_with_comparison(
    results: NSResultsType,
    model_name_1: str,
    model_name_2: str,
    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    comparison_name: str,
) -> None:
    """
    Compares the logged values from `model_name_2` against the corresponding
    values in `model_name_1`, using `comparison_fn`. Records the result
    in `model_name_2`'s results under `comparison_name`.
    """
    for _, results_type_to_results in results.items():
        for _, model_name_to_results in results_type_to_results.items():
            assert model_name_1 in model_name_to_results, \
                f"{model_name_1} not found in results"
            assert model_name_2 in model_name_to_results, \
                f"{model_name_2} not found in results"

            results_1 = model_name_to_results[model_name_1]
            results_2 = model_name_to_results[model_name_2]

            for result_2 in results_2:
                index_within_arg_2 = result_2['index_within_arg']
                index_of_arg_2 = result_2['index_of_arg']
                # find corresponding result_1
                result_1 = None
                for cur_result_1 in results_1:
                    index_within_arg_1 = cur_result_1['index_within_arg']
                    index_of_arg_1 = cur_result_1['index_of_arg']
                    if (
                        (index_within_arg_1 == index_within_arg_2) and
                        (index_of_arg_1 == index_of_arg_2)
                    ):
                        result_1 = cur_result_1
                        break
                assert result_1 is not None

                values_1 = result_1['values']
                values_2 = result_2['values']
                result_2[comparison_name] = []
                for value_1, value_2 in zip(values_1, values_2):
                    comparison_result = comparison_fn(value_1, value_2)
                    result_2[comparison_name].append(comparison_result)
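

# Editor's note -- a minimal usage sketch, not part of the library. Given
# `results` from one of the extract_* calls above, attach per-layer SQNR
# values to the 'int8' entries:
#
#     from torch.ao.ns.fx.utils import compute_sqnr
#     extend_logger_results_with_comparison(results, 'fp32', 'int8', compute_sqnr, 'sqnr')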

@@ -742,8 +742,8 @@ def get_layer_sqnr_dict(model_a: nn.Module, model_b: nn.Module, x: torch.Tensor)
        model_b: A quantized model
        x: Inputs to use during calibration
    """
    import torch.quantization._numeric_suite_fx as ns
    from torch.quantization.ns.mappings import get_unmatchable_types_map
    import torch.ao.ns._numeric_suite_fx as ns
    from torch.ao.ns.fx.mappings import get_unmatchable_types_map

    unmatchable_types_map = get_unmatchable_types_map()
    unmatchable_types_map["funs_unmatchable"].add(torch.mul)

@@ -766,7 +766,7 @@ def get_layer_sqnr_dict(model_a: nn.Module, model_b: nn.Module, x: torch.Tensor)
    ns.extend_logger_results_with_comparison(
        activation_comparison_dict,
        'fp32', 'int8',
        torch.quantization.ns.utils.compute_sqnr, 'sqnr'
        torch.ao.ns.fx.utils.compute_sqnr, 'sqnr'
    )

    # Construct a dictionary mapping layer names to the SQNR values

@@ -33,7 +33,7 @@ try:
        prepare_qat_fx,
        convert_fx,
    )
    from torch.quantization.ns.ns_types import NSSingleResultValuesType, NSSubgraph
    from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph
    from torch.fx.graph import Node
    from torch.fx import GraphModule
    HAS_FX = True