torch.ao migration: numeric suite, eager and fx (#64817)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64817

This migrates `torch.quantization._numeric_suite` to `torch.ao.ns._numeric_suite`, and
`torch.quantization._numeric_suite_fx` to `torch.ao.ns._numeric_suite_fx`.

1. move the files

```
HG: move eager mode
hg mv caffe2/torch/quantization/_numeric_suite.py caffe2/torch/ao/ns/
HG: move fx
hg mv caffe2/torch/quantization/_numeric_suite_fx.py caffe2/torch/ao/ns/
hg mv caffe2/torch/quantization/ns/* caffe2/torch/ao/ns/fx/
```

2. create new versions of `_numeric_suite.py` and `_numeric_suite_fx.py` with imports
3. update all FB callsites

Test Plan: buck test mode/dev //caffe2/test:quantization

Reviewed By: z-a-f

Differential Revision: D30867538

fbshipit-source-id: 120ee830434ca490c1183a187a518eebcbbaf22c
parent 39f2b9de2a
commit 1577c106dc
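For downstream code, the change is a one-line import update. The old `torch.quantization` paths keep working for now, because the original files are replaced with thin compatibility shims that re-import everything from the new `torch.ao.ns` location. A minimal sketch (illustrative, not part of the diff below; `float_model` and `qmodel` are placeholder names):

```python
# Before this change (still works via the compatibility shim kept at the old path):
import torch.quantization._numeric_suite as ns_old

# After this change (preferred going forward):
import torch.ao.ns._numeric_suite as ns

# Both names refer to the same module contents, e.g. the eager-mode weight comparison:
# wt_compare_dict = ns.compare_weights(float_model.state_dict(), qmodel.state_dict())
```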
@@ -5,7 +5,7 @@ from torch.testing._internal.common_quantization import skipIfNoFBGEMM
 from torch.quantization import default_qconfig
 from torch.quantization import QuantWrapper
-import torch.quantization._numeric_suite as ns
+import torch.ao.ns._numeric_suite as ns
 
 from torch.quantization._correct_bias import (
     _supported_modules,
@@ -10,7 +10,7 @@ from torch.quantization import (
     quantize,
     quantize_dynamic,
 )
-from torch.quantization._numeric_suite import (
+from torch.ao.ns._numeric_suite import (
     OutputLogger,
     Shadow,
     ShadowLogger,
@@ -34,29 +34,29 @@ from torch.quantization.quantization_mappings import (
 from torch.testing._internal.common_quantization import NodeSpec as ns
 from torch.quantization.fx.pattern_utils import get_default_quant_patterns
 import torch.quantization.fx.quantization_patterns as qp
-from torch.quantization.ns.pattern_utils import (
+from torch.ao.ns.fx.pattern_utils import (
     get_type_a_related_to_b,
 )
-from torch.quantization.ns.graph_matcher import (
+from torch.ao.ns.fx.graph_matcher import (
     get_matching_subgraph_pairs,
     GraphMatchingException,
 )
-from torch.quantization.ns.utils import (
+from torch.ao.ns.fx.utils import (
     compute_sqnr,
     compute_normalized_l2_error,
     compute_cosine_similarity,
 )
-from torch.quantization.ns.mappings import (
+from torch.ao.ns.fx.mappings import (
     get_node_type_to_io_type_map,
     get_unmatchable_types_map,
     get_base_name_to_sets_of_related_ops,
     get_base_name_for_op,
     add_op_to_sets_of_related_ops,
 )
-from torch.quantization.ns.weight_utils import (
+from torch.ao.ns.fx.weight_utils import (
     get_op_to_type_to_weight_extraction_fn,
 )
-from torch.quantization._numeric_suite_fx import (
+from torch.ao.ns._numeric_suite_fx import (
     extract_weights,
     _extract_weights_impl,
     add_loggers,
@@ -1634,7 +1634,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
         op_to_type_to_weight_extraction_fn = \
             get_op_to_type_to_weight_extraction_fn()
         op_to_type_to_weight_extraction_fn['call_function'][_wrapped_linear] = \
-            torch.quantization.ns.weight_utils.get_linear_fun_weight
+            torch.ao.ns.fx.weight_utils.get_linear_fun_weight
 
         # test compare weights
         results = extract_weights(
486 torch/ao/ns/_numeric_suite.py Normal file

@@ -0,0 +1,486 @@
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import prepare
from typing import Dict, List, Optional, Any, Union, Callable, Set

from torch.quantization.quantization_mappings import (
    get_default_compare_output_module_list,
)

NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
    nnqd.Linear,
    nnq.Linear,
    nnqd.LSTM,
    nn.LSTM,
}


def _find_match(
    str_list: Union[Dict[str, Any], List[str]], key_str: str,
    postfix: str,
) -> Optional[str]:
    split_str = key_str.split(".")
    if split_str[-1] == postfix:
        match_string = "".join(key_str.split(".")[0:-1])
        for s2 in str_list:
            pattern1 = "".join(s2.split(".")[0:-1])
            pattern2 = "".join(s2.split(".")[0:-2])
            if match_string == pattern1:
                return s2
            if match_string == pattern2:
                return s2

        # For matching "fc.weight" and "fc._packed_params._packed_params"
        if postfix == "_packed_params":
            match_string = "".join(key_str.split(".")[0:-2])
            if len(match_string) == 0:
                return None
            for s2 in str_list:
                pattern1 = "".join(s2.split(".")[0:-1])
                pattern2 = "".join(s2.split(".")[0:-2])
                if match_string == pattern1:
                    return s2
                if match_string == pattern2:
                    return s2
        return None
    else:
        return None

def compare_weights(
    float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare the weights of the float module with its corresponding quantized
    module. Return a dict with key corresponding to module names and each entry being
    a dictionary with two keys 'float' and 'quantized', containing the float and
    quantized weights. This dict can be used to compare and compute the quantization
    error of the weights of float and quantized models.

    Example usage:
        wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())
        for key in wt_compare_dict:
            print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))

    Args:
        float_dict: state dict of the float model
        quantized_dict: state dict of the quantized model

    Return:
        weight_dict: dict with key corresponding to module names and each entry being
        a dictionary with two keys 'float' and 'quantized', containing the float and
        quantized weights
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
    weight_dict: Dict[str, Dict] = {}
    for key in quantized_dict:
        match_key = _find_match(float_dict, key, "weight")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            weight_dict[key]["quantized"] = quantized_dict[key]
            continue

        # For matching "fc.weight" and "fc._packed_params._packed_params"
        match_key = _find_match(float_dict, key, "_packed_params")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            weight_dict[key]["quantized"] = quantized_dict[key][0]

        # For LSTM
        split_str = key.split(".")
        if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
            layer = split_str[-2]
            module_name = ".".join(split_str[:-3])
            float_weight_ih_key = module_name + ".weight_ih_l" + layer
            float_weight_hh_key = module_name + ".weight_hh_l" + layer
            if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
                weight_dict[key] = {}
                weight_dict[key]["float"] = float_dict[float_weight_ih_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
                )
                weight_dict[key]["float"] = float_dict[float_weight_hh_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
                )

    return weight_dict


def _get_logger_dict_helper(
    mod: nn.Module, target_dict: Dict[str, Any],
    prefix: str = "",
) -> None:
    r"""This is the helper function for get_logger_dict

    Args:
        mod: module we want to save all logger stats
        prefix: prefix for the current module
        target_dict: the dictionary used to save all logger stats
    """

    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + "."

    for name, child in mod.named_children():
        if isinstance(child, Logger):
            target_dict[get_prefix(prefix) + "stats"] = child.stats
            break

    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        _get_logger_dict_helper(child, target_dict, module_prefix)


def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
    r"""Traverse the modules and save all logger stats into target dict.
    This is mainly used for quantization accuracy debug.

    Type of loggers supported:
        ShadowLogger: used to log the outputs of the quantized module and its
            matching float shadow module,
        OutputLogger: used to log the outputs of the modules

    Args:
        mod: module we want to save all logger stats
        prefix: prefix for the current module

    Return:
        target_dict: the dictionary used to save all logger stats
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")

    target_dict: Dict[str, Dict] = {}
    _get_logger_dict_helper(mod, target_dict, prefix)
    return target_dict


class Logger(nn.Module):
    r"""Base class for stats logging
    """

    def __init__(self):
        super(Logger, self).__init__()
        self.stats = {}
        # We only insert observer if the op is quantized with static quantization,
        # which is identified by activation_observer.dtype == quint8. This is needed
        # when attaching Logger as observer for FX mode
        self.dtype = torch.quint8

    def forward(self, x):
        pass


class ShadowLogger(Logger):
    r"""Class used in Shadow module to record the outputs of the original and
    shadow modules.
    """

    def __init__(self):
        super(ShadowLogger, self).__init__()
        self.stats["float"] = []
        self.stats["quantized"] = []

    def forward(self, x, y):
        if len(x) > 1:
            x = x[0]
        if len(y) > 1:
            y = y[0]
        self.stats["quantized"].append(x.detach())
        self.stats["float"].append(y.detach())


class OutputLogger(Logger):
    r"""Class used to log the outputs of the module
    """

    def __init__(self):
        super(OutputLogger, self).__init__()
        self.stats["tensor_val"] = []

    def forward(self, x):
        self.stats["tensor_val"].append(x)
        return x


def _convert_tuple_to_list(t: Any) -> Any:
    return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t


def _dequantize_tensor_list(t: Any) -> Any:
    return (
        list(_dequantize_tensor_list(x) for x in t)
        if type(t) is list
        else t.dequantize()
        if t.is_quantized
        else t
    )

class Shadow(nn.Module):
    r"""Shadow module attaches the float module to its matching quantized module
    as the shadow. Then it uses Logger module to process the outputs of both
    modules.

    Args:
        q_module: module quantized from float_module that we want to shadow
        float_module: float module used to shadow q_module
        logger_cls: type of logger used to process the outputs of q_module and
            float_module. ShadowLogger or custom loggers can be used.
    """

    def __init__(self, q_module, float_module, logger_cls):
        super(Shadow, self).__init__()
        self.orig_module = q_module
        self.shadow_module = float_module
        self.dequant = nnq.DeQuantize()
        self.logger = logger_cls()

    def forward(self, *x) -> torch.Tensor:
        xl = _convert_tuple_to_list(x)
        output = self.orig_module(*xl)
        xl_float = _dequantize_tensor_list(xl)
        shadow_output = self.shadow_module(*xl_float)
        self.logger(output, shadow_output)
        return output

    def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = self.orig_module.add(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.add(x, y)
        self.logger(output, shadow_output)
        return output

    def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        output = self.orig_module.add_scalar(x, y)
        x = x.dequantize()
        shadow_output = self.shadow_module.add_scalar(x, y)
        self.logger(output, shadow_output)
        return output

    def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = self.orig_module.mul(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.mul(x, y)
        self.logger(output, shadow_output)
        return output

    def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        output = self.orig_module.mul_scalar(x, y)
        x = x.dequantize()
        shadow_output = self.shadow_module.mul_scalar(x, y)
        self.logger(output, shadow_output)
        return output

    def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        output = self.orig_module.cat(x, dim)
        x = [y.dequantize() for y in x]
        shadow_output = self.shadow_module.cat(x, dim)
        self.logger(output, shadow_output)
        return output

    def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = self.orig_module.add_relu(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.add_relu(x, y)
        self.logger(output, shadow_output)
        return output


def prepare_model_with_stubs(
    float_module: nn.Module, q_module: nn.Module,
    module_swap_list: Set[type], logger_cls: Callable,
) -> None:
    r"""Prepare the model by attaching the float module to its matching quantized
    module as the shadow if the float module type is in module_swap_list.

    Example usage:
        prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
        q_model(data)
        ob_dict = get_logger_dict(q_model)

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        module_swap_list: list of float module types to attach the shadow
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")

    float_module_children = {}
    for name, mod in float_module.named_children():
        float_module_children[name] = mod

    reassign = {}
    for name, mod in q_module.named_children():

        if name not in float_module_children:
            continue

        float_mod = float_module_children[name]

        if type(float_mod) not in module_swap_list:
            prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls)

        # Insert shadow module only if the module is not of the same type as
        # the floating point module
        if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod):
            reassign[name] = Shadow(mod, float_mod, logger_cls)

    for key, value in reassign.items():
        q_module._modules[key] = value


def _is_identical_module_type(mod1, mod2):
    # Compare if two modules have the same dtype
    mod1_module_types = [type(mod) for mod in mod1.modules()]
    mod2_module_types = [type(mod) for mod in mod2.modules()]
    return mod1_module_types == mod2_module_types

def compare_model_stub(
    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
    *data, logger_cls=ShadowLogger
) -> Dict[str, Dict]:
    r"""Compare quantized module in a model with its floating point counterpart,
    feeding both of them the same input. Return a dict with key corresponding to
    module names and each entry being a dictionary with two keys 'float' and
    'quantized', containing the output tensors of quantized and its matching
    float shadow module. This dict can be used to compare and compute the module
    level quantization error.

    This function first calls prepare_model_with_stubs() to swap the quantized
    module that we want to compare with the Shadow module, which takes the
    quantized module, the corresponding float module and a logger as input, and
    creates a forward path inside so that the float module shadows the quantized
    module, sharing the same input. The logger is customizable; the default logger
    is ShadowLogger, and it will save the outputs of the quantized module and
    float module, which can be used to compute the module level quantization error.

    Example usage:
        module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
        ob_dict = compare_model_stub(float_model, qmodel, module_swap_list, data)
        for key in ob_dict:
            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        module_swap_list: list of float module types at which shadow modules will
            be attached.
        data: input data used to run the prepared q_model
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
    prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
    q_model(*data)
    ob_dict = get_logger_dict(q_model)
    return ob_dict


def get_matching_activations(
    float_module: nn.Module, q_module: nn.Module,
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Find the matching activation between float and quantized modules.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module

    Return:
        act_dict: dict with key corresponding to quantized module names and each
        entry being a dictionary with two keys 'float' and 'quantized', containing
        the matching float and quantized activations
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
    float_dict = get_logger_dict(float_module)
    quantized_dict = get_logger_dict(q_module)
    act_dict: Dict[str, Dict] = {}
    for key in quantized_dict:
        match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
        if match_key is not None:
            act_dict[key] = {}
            act_dict[key]["float"] = float_dict[match_key]["tensor_val"]
            act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"]
    return act_dict


def prepare_model_outputs(
    float_module: nn.Module,
    q_module: nn.Module,
    logger_cls=OutputLogger,
    allow_list=None
) -> None:
    r"""Prepare the model by attaching the logger to both float module
    and quantized module if they are in the allow_list.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
    if allow_list is None:
        allow_list = get_default_compare_output_module_list()

    qconfig_debug = torch.quantization.QConfig(activation=logger_cls, weight=None)
    float_module.qconfig = qconfig_debug  # type: ignore[assignment]
    prepare(float_module, inplace=True, allow_list=allow_list)
    q_module.qconfig = qconfig_debug  # type: ignore[assignment]
    prepare(
        q_module,
        inplace=True,
        allow_list=allow_list,
        observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
    )


def compare_model_outputs(
    float_model: nn.Module,
    q_model: nn.Module,
    *data,
    logger_cls=OutputLogger,
    allow_list=None
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare output activations between float and quantized models at
    corresponding locations for the same input. Return a dict with key corresponding
    to quantized module names and each entry being a dictionary with two keys
    'float' and 'quantized', containing the activations of quantized model and
    float model at matching locations. This dict can be used to compare and
    compute the propagation quantization error.

    Example usage:
        act_compare_dict = compare_model_outputs(float_model, qmodel, data)
        for key in act_compare_dict:
            print(key, compute_error(act_compare_dict[key]['float'], act_compare_dict[key]['quantized'].dequantize()))

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        data: input data used to run the prepared float_model and q_model
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger

    Return:
        act_compare_dict: dict with key corresponding to quantized module names
        and each entry being a dictionary with two keys 'float' and 'quantized',
        containing the matching float and quantized activations
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
    if allow_list is None:
        allow_list = get_default_compare_output_module_list()
    prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
    float_model(*data)
    q_model(*data)
    act_compare_dict = get_matching_activations(float_model, q_model)
    return act_compare_dict
513 torch/ao/ns/_numeric_suite_fx.py Normal file

@@ -0,0 +1,513 @@
import collections

import torch
import torch.nn as nn
import torch.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.ns.fx.mappings import (
    get_base_name_to_sets_of_related_ops,
)
from torch.ao.ns.fx.graph_matcher import (
    get_matching_subgraph_pairs,
    get_type_a_related_to_b,
)

from .fx.weight_utils import (
    extract_weight_from_node,
)

from .fx.graph_passes import (
    add_loggers_to_model,
    create_a_shadows_b,
)

from .fx.utils import (
    rekey_logger_info_on_node_name_of_model,
    maybe_add_missing_fqns,
    get_target_type_str,
)

from .fx.ns_types import (
    NSSingleResultValuesType,
    NSResultsType,
    NSNodeTargetType,
)

from typing import Dict, Tuple, Callable, List, Optional, Set

RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

class OutputLogger(nn.Module):
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        ref_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
        fqn: Optional[str],
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # name of the node which was responsible for adding this logger
        # Note:
        # - if we are logging node outputs, this is the same as prev_node_name
        # - if we are logging node inputs, this is the name of the node
        #   whose input this logger is logging.
        #
        # example, where logger1 is logging input of op1 and logger2 is logging
        # the output of op1:
        #
        #  x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # in this example,
        #   - logger1's prev_node_name is x1 and ref_node_name is op1
        #   - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # type of the target of the node which was responsible for adding this
        # logger
        self.ref_node_target_type = ref_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node
        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
        self.index_within_arg = index_within_arg
        # index of this node within the args of the input/output node
        # for example, in add(x1, x2), x2 would have index_of_arg == 1
        self.index_of_arg = index_of_arg
        # fully qualified name
        self.fqn = fqn

    # Note: cannot annotate the type of x because TorchScript does not support
    # the Union type.
    def forward(self, x):
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        return x

    def __repr__(self):
        return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
ref_node_target_type={self.ref_node_target_type}
results_type={self.results_type}, index_within_arg={self.index_within_arg},
index_of_arg={self.index_of_arg}, fqn={self.fqn})"""


class NSTracer(quantize_fx.QuantizationTracer):
    """
    Just like a regular tracer, but treats observers and fake_quantize
    modules as leaf modules.
    """
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
        if isinstance(m, torch.quantization.ObserverBase):
            return True
        elif isinstance(m, torch.quantization.FakeQuantizeBase):
            return True
        return super().is_leaf_module(m, module_qualified_name)


def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> None:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
    for node, ref_name in nodes_and_names_to_instrument:
        res_type = NSSingleResultValuesType.WEIGHT.value
        extracted_weight = extract_weight_from_node(
            node, model, op_to_type_to_weight_extraction_fn)
        if extracted_weight:
            if ref_name not in results:
                results[ref_name] = {res_type: {}}
            results[ref_name][res_type][model_name] = [extracted_weight]


def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)

    # split the subgraph pairs into one data structure for each model
    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
        nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))

    # populate the results, one model at a time
    results: NSResultsType = {}
    _extract_weights_one_model(
        model_name_a, gm_a, nodes_and_names_to_instrument_a, results,
        op_to_type_to_weight_extraction_fn)
    _extract_weights_one_model(
        model_name_b, gm_b, nodes_and_names_to_instrument_b, results,
        op_to_type_to_weight_extraction_fn)

    # fill in missing fqn entries
    maybe_add_missing_fqns(results)

    # rekey on names of nodes in gm_b
    results = rekey_logger_info_on_node_name_of_model(results, model_name_b)

    return results


def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = \
            get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _extract_weights_impl(
        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map, op_to_type_to_weight_extraction_fn)


def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
    logger_cls: Callable,
) -> nn.Module:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")

    # TODO(future PR): do not observe nodes we do not care
    #   about (both fp32, denylist, etc)
    node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
        node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
        node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)

    model = add_loggers_to_model(
        model, node_to_instrument_inputs_to_ref_name,
        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
    return model


def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b,
        base_name_to_sets_of_related_ops, unmatchable_types_map)
    nodes_and_names_to_instrument_inputs_a = []
    nodes_and_names_to_instrument_inputs_b = []
    nodes_and_names_to_instrument_outputs_a = []
    nodes_and_names_to_instrument_outputs_b = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        # Note: for matching inputs we use start_node, such as observing
        # the input of linear in linear-relu
        if should_log_inputs:
            nodes_and_names_to_instrument_inputs_a.append(
                (subgraph_a.start_node, match_name, ref_node_type_a))
            nodes_and_names_to_instrument_inputs_b.append(
                (subgraph_b.start_node, match_name, ref_node_type_b))
        # Note: for matching activations we always use end_node,
        # such as observing the output of relu in linear-relu
        nodes_and_names_to_instrument_outputs_a.append(
            (subgraph_a.end_node, match_name, ref_node_type_a))
        nodes_and_names_to_instrument_outputs_b.append(
            (subgraph_b.end_node, match_name, ref_node_type_b))

    new_model_a = _add_loggers_one_model(
        name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
        nodes_and_names_to_instrument_outputs_a, logger_cls)
    new_model_b = _add_loggers_one_model(
        name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
        nodes_and_names_to_instrument_outputs_b, logger_cls)
    return (new_model_a, new_model_b)

def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs : bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map)


def _extract_logger_info_one_model(
    model: nn.Module,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
    for gm_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore[arg-type]
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if is_logger:
            key = mod.ref_name
            if key not in results:
                results[key] = {}
            assert mod.model_name not in results[key], \
                f"{mod.model_name} is already present in results"
            if mod.results_type not in results[key]:
                results[key][mod.results_type] = {}
            if mod.model_name not in results[key][mod.results_type]:
                results[key][mod.results_type][mod.model_name] = []
            stats_to_use = mod.stats
            if len(mod.stats_rnn) > 0:
                stats_to_use = mod.stats_rnn
            results[key][mod.results_type][mod.model_name].append({
                'type': mod.results_type,
                'values': stats_to_use,
                'ref_node_name': mod.ref_node_name,
                'ref_node_target_type': mod.ref_node_target_type,
                'prev_node_name': mod.prev_node_name,
                'prev_node_target_type': mod.prev_node_target_type,
                'index_within_arg': mod.index_within_arg,
                'index_of_arg': mod.index_of_arg,
                'fqn': mod.fqn,
            })
            # ensure the list stays sorted
            results[key][mod.results_type][mod.model_name].sort(
                key=lambda res:
                f"{res['index_of_arg']}:{res['index_within_arg']}"
            )


# TODO(future PR): align on naming
# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as ns.extract_logger_info, but for models prepared with
    this module.

    TODO(future PR): real docblock

    Output format: NSResultsType
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
    results: NSResultsType = {}
    for model in (model_a, model_b):
        _extract_logger_info_one_model(model, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return results


def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
    gm_a_shadows_b = create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map)
    return gm_a_shadows_b


def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Same thing as add_loggers, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map)


def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as extract_logger_info, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
    results: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return dict(results)


def extend_logger_results_with_comparison(
    results: NSResultsType,
    model_name_1: str,
    model_name_2: str,
    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    comparison_name: str,
) -> None:
    """
    Compares the logged values from `model_name_2` against the corresponding
    values in `model_name_1`, using `comparison_fn`. Records the result
    in `model_name_2`'s results under `comparison_name`.
    """
    for _, results_type_to_results in results.items():
        for _, model_name_to_results in results_type_to_results.items():
            assert model_name_1 in model_name_to_results, \
                f"{model_name_1} not found in results"
            assert model_name_2 in model_name_to_results, \
                f"{model_name_2} not found in results"

            results_1 = model_name_to_results[model_name_1]
            results_2 = model_name_to_results[model_name_2]

            for result_2 in results_2:
                index_within_arg_2 = result_2['index_within_arg']
                index_of_arg_2 = result_2['index_of_arg']
                # find corresponding result_1
                result_1 = None
                for cur_result_1 in results_1:
                    index_within_arg_1 = cur_result_1['index_within_arg']
                    index_of_arg_1 = cur_result_1['index_of_arg']
                    if (
                        (index_within_arg_1 == index_within_arg_2) and
                        (index_of_arg_1 == index_of_arg_2)
                    ):
                        result_1 = cur_result_1
                        break
                assert result_1 is not None

                values_1 = result_1['values']
                values_2 = result_2['values']
                result_2[comparison_name] = []
                for value_1, value_2 in zip(values_1, values_2):
                    comparison_result = comparison_fn(value_1, value_2)
                    result_2[comparison_name].append(comparison_result)
0 torch/ao/ns/fx/__init__.py Normal file
@@ -19,7 +19,7 @@ from .ns_types import (
     NSSubgraph,
     NSNodeTargetType,
 )
-from torch.quantization.ns.mappings import (
+from torch.ao.ns.fx.mappings import (
     get_node_type_to_io_type_map,
 )
 from torch.quantization.quantize import is_activation_post_process
@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.nn.quantized as nnq
 
 import torch.quantization
-import torch.quantization._numeric_suite as ns
+import torch.ao.ns._numeric_suite as ns
 
 _supported_modules = {nn.Linear, nn.Conv2d}
 _supported_modules_quantized = {nnq.Linear, nnq.Conv2d}
@ -1,486 +1,28 @@
|
||||||
import torch
|
# flake8: noqa: F401
|
||||||
import torch.nn as nn
|
r"""
|
||||||
import torch.nn.quantized as nnq
|
This file is in the process of migration to `torch/ao/quantization`, and
|
||||||
import torch.nn.quantized.dynamic as nnqd
|
is kept here for compatibility while the migration process is ongoing.
|
||||||
from torch.quantization import prepare
|
If you are adding a new entry/functionality, please, add it to the
|
||||||
from typing import Dict, List, Optional, Any, Union, Callable, Set
|
`torch/ao/ns/_numeric_suite.py`, while adding an import statement
|
||||||
|
here.
|
||||||
|
"""
|
||||||
|
|
||||||
from .quantization_mappings import (
|
from torch.ao.ns._numeric_suite import (
|
||||||
get_default_compare_output_module_list,
|
NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
|
||||||
|
_find_match,
|
||||||
|
compare_weights,
|
||||||
|
_get_logger_dict_helper,
|
||||||
|
get_logger_dict,
|
||||||
|
Logger,
|
||||||
|
ShadowLogger,
|
||||||
|
OutputLogger,
|
||||||
|
_convert_tuple_to_list,
|
||||||
|
_dequantize_tensor_list,
|
||||||
|
Shadow,
|
||||||
|
prepare_model_with_stubs,
|
||||||
|
_is_identical_module_type,
|
||||||
|
compare_model_stub,
|
||||||
|
get_matching_activations,
|
||||||
|
prepare_model_outputs,
|
||||||
|
compare_model_outputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
|
|
||||||
nnqd.Linear,
|
|
||||||
nnq.Linear,
|
|
||||||
nnqd.LSTM,
|
|
||||||
nn.LSTM,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _find_match(
|
|
||||||
str_list: Union[Dict[str, Any], List[str]], key_str: str,
|
|
||||||
postfix: str,
|
|
||||||
) -> Optional[str]:
|
|
||||||
split_str = key_str.split(".")
|
|
||||||
if split_str[-1] == postfix:
|
|
||||||
match_string = "".join(key_str.split(".")[0:-1])
|
|
||||||
for s2 in str_list:
|
|
||||||
pattern1 = "".join(s2.split(".")[0:-1])
|
|
||||||
pattern2 = "".join(s2.split(".")[0:-2])
|
|
||||||
if match_string == pattern1:
|
|
||||||
return s2
|
|
||||||
if match_string == pattern2:
|
|
||||||
return s2
|
|
||||||
|
|
||||||
# For matching "fc.weight" and "fc._packed_params._packed_params"
|
|
||||||
if postfix == "_packed_params":
|
|
||||||
match_string = "".join(key_str.split(".")[0:-2])
|
|
||||||
if len(match_string) == 0:
|
|
||||||
return None
|
|
||||||
for s2 in str_list:
|
|
||||||
pattern1 = "".join(s2.split(".")[0:-1])
|
|
||||||
pattern2 = "".join(s2.split(".")[0:-2])
|
|
||||||
if match_string == pattern1:
|
|
||||||
return s2
|
|
||||||
if match_string == pattern2:
|
|
||||||
return s2
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def compare_weights(
    float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare the weights of the float module with its corresponding quantized
    module. Return a dict with key corresponding to module names and each entry being
    a dictionary with two keys 'float' and 'quantized', containing the float and
    quantized weights. This dict can be used to compare and compute the quantization
    error of the weights of float and quantized models.

    Example usage:
        wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())
        for key in wt_compare_dict:
            print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))

    Args:
        float_dict: state dict of the float model
        quantized_dict: state dict of the quantized model

    Return:
        weight_dict: dict with key corresponding to module names and each entry being
        a dictionary with two keys 'float' and 'quantized', containing the float and
        quantized weights
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
    weight_dict: Dict[str, Dict] = {}
    for key in quantized_dict:
        match_key = _find_match(float_dict, key, "weight")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            weight_dict[key]["quantized"] = quantized_dict[key]
            continue

        # For matching "fc.weight" and "fc._packed_params._packed_params"
        match_key = _find_match(float_dict, key, "_packed_params")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            weight_dict[key]["quantized"] = quantized_dict[key][0]

        # For LSTM
        split_str = key.split(".")
        if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
            layer = split_str[-2]
            module_name = ".".join(split_str[:-3])
            float_weight_ih_key = module_name + ".weight_ih_l" + layer
            float_weight_hh_key = module_name + ".weight_hh_l" + layer
            if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
                weight_dict[key] = {}
                weight_dict[key]["float"] = float_dict[float_weight_ih_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
                )
                weight_dict[key]["float"] = float_dict[float_weight_hh_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
                )

    return weight_dict

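
# ---------------------------------------------------------------------------
# Editor's note: the example below is an illustrative sketch and is not part
# of the original file. It shows one way to drive `compare_weights` end to
# end with a dynamically quantized model; the toy model, the `compute_error`
# helper and the function name are hypothetical.
def _example_compare_weights_usage():
    import torch
    import torch.nn as nn
    import torch.quantization as tq
    import torch.ao.ns._numeric_suite as ns

    def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # SQNR (in dB) between the float weight and its dequantized counterpart
        return 20 * torch.log10(torch.norm(x) / torch.norm(x - y))

    float_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()).eval()
    q_model = tq.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

    wt_compare_dict = ns.compare_weights(
        float_model.state_dict(), q_model.state_dict())
    for key in wt_compare_dict:
        print(key, compute_error(wt_compare_dict[key]['float'],
                                 wt_compare_dict[key]['quantized'].dequantize()))
# ---------------------------------------------------------------------------
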
def _get_logger_dict_helper(
    mod: nn.Module, target_dict: Dict[str, Any],
    prefix: str = "",
) -> None:
    r"""This is the helper function for get_logger_dict

    Args:
        mod: module we want to save all logger stats
        prefix: prefix for the current module
        target_dict: the dictionary used to save all logger stats
    """

    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + "."

    for name, child in mod.named_children():
        if isinstance(child, Logger):
            target_dict[get_prefix(prefix) + "stats"] = child.stats
            break

    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        _get_logger_dict_helper(child, target_dict, module_prefix)


def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
    r"""Traverse the modules and save all logger stats into target dict.
    This is mainly used for quantization accuracy debugging.

    Type of loggers supported:
        ShadowLogger: used to log the outputs of the quantized module and its
            matching float shadow module,
        OutputLogger: used to log the outputs of the modules

    Args:
        mod: module we want to save all logger stats
        prefix: prefix for the current module

    Return:
        target_dict: the dictionary used to save all logger stats
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")

    target_dict: Dict[str, Dict] = {}
    _get_logger_dict_helper(mod, target_dict, prefix)
    return target_dict


class Logger(nn.Module):
    r"""Base class for stats logging
    """

    def __init__(self):
        super(Logger, self).__init__()
        self.stats = {}
        # We only insert observer if the op is quantized with static quantization,
        # which is identified by activation_observer.dtype == quint8. This is needed
        # when attaching Logger as observer for FX mode
        self.dtype = torch.quint8

    def forward(self, x):
        pass


class ShadowLogger(Logger):
    r"""Class used in Shadow module to record the outputs of the original and
    shadow modules.
    """

    def __init__(self):
        super(ShadowLogger, self).__init__()
        self.stats["float"] = []
        self.stats["quantized"] = []

    def forward(self, x, y):
        if len(x) > 1:
            x = x[0]
        if len(y) > 1:
            y = y[0]
        self.stats["quantized"].append(x.detach())
        self.stats["float"].append(y.detach())


class OutputLogger(Logger):
    r"""Class used to log the outputs of the module
    """

    def __init__(self):
        super(OutputLogger, self).__init__()
        self.stats["tensor_val"] = []

    def forward(self, x):
        self.stats["tensor_val"].append(x)
        return x


def _convert_tuple_to_list(t: Any) -> Any:
    return list(_convert_tuple_to_list(x) for x in t) if type(t) is tuple else t


def _dequantize_tensor_list(t: Any) -> Any:
    return (
        list(_dequantize_tensor_list(x) for x in t)
        if type(t) is list
        else t.dequantize()
        if t.is_quantized
        else t
    )


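# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. Because
# `Logger` is the base class that the observer machinery instantiates, a
# custom statistic can be collected by subclassing it and passing the class
# as `logger_cls` to `prepare_model_outputs` below; the stats are then read
# back with `get_logger_dict`. The class name here is hypothetical.
class _ExampleMeanLogger(Logger):
    def __init__(self):
        super(_ExampleMeanLogger, self).__init__()
        self.stats["mean"] = []

    def forward(self, x):
        # record the per-call mean instead of the full tensor
        if isinstance(x, torch.Tensor):
            val = x.dequantize() if x.is_quantized else x
            self.stats["mean"].append(val.detach().mean())
        return x
# ---------------------------------------------------------------------------
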
class Shadow(nn.Module):
    r"""Shadow module attaches the float module to its matching quantized module
    as the shadow. Then it uses Logger module to process the outputs of both
    modules.

    Args:
        q_module: module quantized from float_module that we want to shadow
        float_module: float module used to shadow q_module
        logger_cls: type of logger used to process the outputs of q_module and
            float_module. ShadowLogger or custom loggers can be used.
    """

    def __init__(self, q_module, float_module, logger_cls):
        super(Shadow, self).__init__()
        self.orig_module = q_module
        self.shadow_module = float_module
        self.dequant = nnq.DeQuantize()
        self.logger = logger_cls()

    def forward(self, *x) -> torch.Tensor:
        xl = _convert_tuple_to_list(x)
        output = self.orig_module(*xl)
        xl_float = _dequantize_tensor_list(xl)
        shadow_output = self.shadow_module(*xl_float)
        self.logger(output, shadow_output)
        return output

    def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = self.orig_module.add(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.add(x, y)
        self.logger(output, shadow_output)
        return output

    def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        output = self.orig_module.add_scalar(x, y)
        x = x.dequantize()
        shadow_output = self.shadow_module.add_scalar(x, y)
        self.logger(output, shadow_output)
        return output

    def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = self.orig_module.mul(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.mul(x, y)
        self.logger(output, shadow_output)
        return output

    def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        output = self.orig_module.mul_scalar(x, y)
        x = x.dequantize()
        shadow_output = self.shadow_module.mul_scalar(x, y)
        self.logger(output, shadow_output)
        return output

    def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        output = self.orig_module.cat(x, dim)
        x = [y.dequantize() for y in x]
        shadow_output = self.shadow_module.cat(x, dim)
        self.logger(output, shadow_output)
        return output

    def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = self.orig_module.add_relu(x, y)
        x = x.dequantize()
        y = y.dequantize()
        shadow_output = self.shadow_module.add_relu(x, y)
        self.logger(output, shadow_output)
        return output


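# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. It wraps
# a single dynamically quantized Linear with `Shadow` by hand, which is what
# `prepare_model_with_stubs` below does for every module type listed in
# `module_swap_list`. The toy model and the function name are hypothetical.
def _example_manual_shadow_wrap():
    import torch
    import torch.nn as nn
    import torch.quantization as tq
    import torch.ao.ns._numeric_suite as ns

    float_model = nn.Sequential(nn.Linear(8, 8)).eval()
    q_model = tq.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

    # attach the float Linear as the shadow of its quantized counterpart
    q_model._modules['0'] = ns.Shadow(q_model[0], float_model[0], ns.ShadowLogger)
    q_model(torch.randn(2, 8))

    # '0.stats' now holds matching 'float' and 'quantized' outputs
    print(ns.get_logger_dict(q_model))
# ---------------------------------------------------------------------------
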
def prepare_model_with_stubs(
    float_module: nn.Module, q_module: nn.Module,
    module_swap_list: Set[type], logger_cls: Callable,
) -> None:
    r"""Prepare the model by attaching the float module to its matching quantized
    module as the shadow if the float module type is in module_swap_list.

    Example usage:
        prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
        q_model(data)
        ob_dict = get_logger_dict(q_model)

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        module_swap_list: list of float module types to attach the shadow
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")

    float_module_children = {}
    for name, mod in float_module.named_children():
        float_module_children[name] = mod

    reassign = {}
    for name, mod in q_module.named_children():

        if name not in float_module_children:
            continue

        float_mod = float_module_children[name]

        if type(float_mod) not in module_swap_list:
            prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls)

        # Insert shadow module only if the module is not of the same type as
        # the floating point module
        if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod):
            reassign[name] = Shadow(mod, float_mod, logger_cls)

    for key, value in reassign.items():
        q_module._modules[key] = value


def _is_identical_module_type(mod1, mod2):
    # Compare if two modules have the same dtype
    mod1_module_types = [type(mod) for mod in mod1.modules()]
    mod2_module_types = [type(mod) for mod in mod2.modules()]
    return mod1_module_types == mod2_module_types


def compare_model_stub(
    float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
    *data, logger_cls=ShadowLogger
) -> Dict[str, Dict]:
    r"""Compare quantized module in a model with its floating point counterpart,
    feeding both of them the same input. Return a dict with key corresponding to
    module names and each entry being a dictionary with two keys 'float' and
    'quantized', containing the output tensors of quantized and its matching
    float shadow module. This dict can be used to compare and compute the module
    level quantization error.

    This function first calls prepare_model_with_stubs() to swap the quantized
    modules that we want to compare with Shadow modules, which take the quantized
    module, the corresponding float module and a logger as input, and create a
    forward path inside so that the float module shadows the quantized module and
    both share the same input. The logger is customizable; the default logger is
    ShadowLogger and it will save the outputs of the quantized module and float
    module that can be used to compute the module level quantization error.

    Example usage:
        module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
        ob_dict = compare_model_stub(float_model, qmodel, module_swap_list, data)
        for key in ob_dict:
            print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        module_swap_list: list of float module types at which shadow modules will
            be attached.
        data: input data used to run the prepared q_model
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
    prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
    q_model(*data)
    ob_dict = get_logger_dict(q_model)
    return ob_dict


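# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. It runs
# `compare_model_stub` on a small dynamically quantized model; the toy model
# and the function name are hypothetical.
def _example_compare_model_stub_usage():
    import torch
    import torch.nn as nn
    import torch.quantization as tq
    import torch.ao.ns._numeric_suite as ns

    float_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()).eval()
    q_model = tq.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)
    data = torch.randn(2, 8)

    # shadow every Linear in the quantized model with its float counterpart
    module_swap_list = [nn.Linear]
    ob_dict = ns.compare_model_stub(float_model, q_model, module_swap_list, data)
    for key, value in ob_dict.items():
        print(key, torch.norm(value['float'][0] - value['quantized'][0]))
# ---------------------------------------------------------------------------
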
def get_matching_activations(
    float_module: nn.Module, q_module: nn.Module,
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Find the matching activation between float and quantized modules.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module

    Return:
        act_dict: dict with key corresponding to quantized module names and each
        entry being a dictionary with two keys 'float' and 'quantized', containing
        the matching float and quantized activations
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
    float_dict = get_logger_dict(float_module)
    quantized_dict = get_logger_dict(q_module)
    act_dict: Dict[str, Dict] = {}
    for key in quantized_dict:
        match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
        if match_key is not None:
            act_dict[key] = {}
            act_dict[key]["float"] = float_dict[match_key]["tensor_val"]
            act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"]
    return act_dict


def prepare_model_outputs(
    float_module: nn.Module,
    q_module: nn.Module,
    logger_cls=OutputLogger,
    allow_list=None
) -> None:
    r"""Prepare the model by attaching the logger to both float module
    and quantized module if they are in the allow_list.

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
    if allow_list is None:
        allow_list = get_default_compare_output_module_list()

    qconfig_debug = torch.quantization.QConfig(activation=logger_cls, weight=None)
    float_module.qconfig = qconfig_debug  # type: ignore[assignment]
    prepare(float_module, inplace=True, allow_list=allow_list)
    q_module.qconfig = qconfig_debug  # type: ignore[assignment]
    prepare(
        q_module,
        inplace=True,
        allow_list=allow_list,
        observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
    )


def compare_model_outputs(
    float_model: nn.Module,
    q_model: nn.Module,
    *data,
    logger_cls=OutputLogger,
    allow_list=None
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare output activations between float and quantized models at
    corresponding locations for the same input. Return a dict with key corresponding
    to quantized module names and each entry being a dictionary with two keys
    'float' and 'quantized', containing the activations of quantized model and
    float model at matching locations. This dict can be used to compare and
    compute the propagation quantization error.

    Example usage:
        act_compare_dict = compare_model_outputs(float_model, qmodel, data)
        for key in act_compare_dict:
            print(key, compute_error(act_compare_dict[key]['float'], act_compare_dict[key]['quantized'].dequantize()))

    Args:
        float_model: float model used to generate the q_model
        q_model: model quantized from float_model
        data: input data used to run the prepared float_model and q_model
        logger_cls: type of logger to be attached to float_module and q_module
        allow_list: list of module types to attach logger

    Return:
        act_compare_dict: dict with key corresponding to quantized module names
        and each entry being a dictionary with two keys 'float' and 'quantized',
        containing the matching float and quantized activations
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
    if allow_list is None:
        allow_list = get_default_compare_output_module_list()
    prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
    float_model(*data)
    q_model(*data)
    act_compare_dict = get_matching_activations(float_model, q_model)
    return act_compare_dict

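
# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. It runs
# `compare_model_outputs` on a small statically quantized model built with the
# eager-mode prepare/convert workflow; the toy model and the function name are
# hypothetical.
def _example_compare_model_outputs_usage():
    import torch
    import torch.nn as nn
    import torch.quantization as tq
    import torch.ao.ns._numeric_suite as ns

    class SmallModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = tq.QuantStub()
            self.fc = nn.Linear(8, 8)
            self.relu = nn.ReLU()
            self.dequant = tq.DeQuantStub()

        def forward(self, x):
            return self.dequant(self.relu(self.fc(self.quant(x))))

    float_model = SmallModel().eval()
    float_model.qconfig = tq.default_qconfig
    prepared = tq.prepare(float_model)
    data = torch.randn(2, 8)
    prepared(data)                      # calibrate
    q_model = tq.convert(prepared)

    act_compare_dict = ns.compare_model_outputs(float_model, q_model, data)
    for key, value in act_compare_dict.items():
        print(key, value['float'][0].shape, value['quantized'][0].shape)
# ---------------------------------------------------------------------------
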
@ -1,513 +1,26 @@
# flake8: noqa: F401
r"""
This file is in the process of migration to `torch/ao/quantization`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
`torch/ao/ns/_numeric_suite_fx.py`, while adding an import statement
here.
"""

from torch.ao.ns._numeric_suite_fx import (
    RNNReturnType,
    OutputLogger,
    NSTracer,
    _extract_weights_one_model,
    _extract_weights_impl,
    extract_weights,
    _add_loggers_one_model,
    _add_loggers_impl,
    add_loggers,
    _extract_logger_info_one_model,
    extract_logger_info,
    _add_shadow_loggers_impl,
    add_shadow_loggers,
    extract_shadow_logger_info,
    extend_logger_results_with_comparison,
)

import collections

import torch
import torch.nn as nn
import torch.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.quantization.ns.mappings import (
    get_base_name_to_sets_of_related_ops,
)
from torch.quantization.ns.graph_matcher import (
    get_matching_subgraph_pairs,
    get_type_a_related_to_b,
)

from .ns.weight_utils import (
    extract_weight_from_node,
)

from .ns.graph_passes import (
    add_loggers_to_model,
    create_a_shadows_b,
)

from .ns.utils import (
    rekey_logger_info_on_node_name_of_model,
    maybe_add_missing_fqns,
    get_target_type_str,
)

from .ns.ns_types import (
    NSSingleResultValuesType,
    NSResultsType,
    NSNodeTargetType,
)

from typing import Dict, Tuple, Callable, List, Optional, Set

RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

class OutputLogger(nn.Module):
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        ref_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
        fqn: Optional[str],
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # name of the node which was responsible for adding this logger
        # Note:
        # - if we are logging node outputs, this is the same as prev_node_name
        # - if we are logging node inputs, this is the name of the node
        #   whose input this logger is logging.
        #
        # example, where logger1 is logging input of op1 and logger2 is logging
        # the output of op1:
        #
        #   x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # in this example,
        #   - logger1's prev_node_name is x1 and ref_node_name is op1
        #   - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated from
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # type of the target of the node which was responsible for adding this
        # logger
        self.ref_node_target_type = ref_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node
        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
        self.index_within_arg = index_within_arg
        # index of this node within the args of the input/output node
        # for example, in add(x1, x2), x2 would have index_of_arg == 1
        self.index_of_arg = index_of_arg
        # fully qualified name
        self.fqn = fqn

    # Note: cannot annotate the type of x because TorchScript does not support
    # the Union type.
    def forward(self, x):
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        return x

    def __repr__(self):
        return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
ref_node_target_type={self.ref_node_target_type}
results_type={self.results_type}, index_within_arg={self.index_within_arg},
index_of_arg={self.index_of_arg}, fqn={self.fqn})"""


class NSTracer(quantize_fx.QuantizationTracer):
    """
    Just like a regular tracer, but treats observers and fake_quantize
    modules as leaf modules.
    """
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
        if isinstance(m, torch.quantization.ObserverBase):
            return True
        elif isinstance(m, torch.quantization.FakeQuantizeBase):
            return True
        return super().is_leaf_module(m, module_qualified_name)


def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> None:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
    for node, ref_name in nodes_and_names_to_instrument:
        res_type = NSSingleResultValuesType.WEIGHT.value
        extracted_weight = extract_weight_from_node(
            node, model, op_to_type_to_weight_extraction_fn)
        if extracted_weight:
            if ref_name not in results:
                results[ref_name] = {res_type: {}}
            results[ref_name][res_type][model_name] = [extracted_weight]


def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)

    # split the subgraph pairs into one data structure for each model
    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
        nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))

    # populate the results, one model at a time
    results: NSResultsType = {}
    _extract_weights_one_model(
        model_name_a, gm_a, nodes_and_names_to_instrument_a, results,
        op_to_type_to_weight_extraction_fn)
    _extract_weights_one_model(
        model_name_b, gm_b, nodes_and_names_to_instrument_b, results,
        op_to_type_to_weight_extraction_fn)

    # fill in missing fqn entries
    maybe_add_missing_fqns(results)

    # rekey on names of nodes in gm_b
    results = rekey_logger_info_on_node_name_of_model(results, model_name_b)

    return results


def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = \
            get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _extract_weights_impl(
        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map, op_to_type_to_weight_extraction_fn)


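# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. It shows
# one way to call `extract_weights` on an FX-quantized model; the import path
# used here is the post-migration location, and the toy model and function
# name are hypothetical.
def _example_extract_weights_usage():
    import copy
    import torch
    import torch.nn as nn
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx
    import torch.ao.ns._numeric_suite_fx as nsfx

    m = nn.Sequential(nn.Linear(8, 8), nn.ReLU()).eval()
    mp = prepare_fx(copy.deepcopy(m), {"": default_qconfig})
    mp(torch.randn(2, 8))               # calibrate
    mq = convert_fx(copy.deepcopy(mp))

    # 'fp32' and 'int8' are arbitrary names used to key the per-model results
    results = nsfx.extract_weights('fp32', mp, 'int8', mq)
    for layer_name, layer_results in results.items():
        print(layer_name, list(layer_results['weight'].keys()))
# ---------------------------------------------------------------------------
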
def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
    logger_cls: Callable,
) -> nn.Module:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")

    # TODO(future PR): do not observe nodes we do not care
    # about (both fp32, denylist, etc)
    node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
        node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
    for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
        node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)

    model = add_loggers_to_model(
        model, node_to_instrument_inputs_to_ref_name,
        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
    return model


def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b,
        base_name_to_sets_of_related_ops, unmatchable_types_map)
    nodes_and_names_to_instrument_inputs_a = []
    nodes_and_names_to_instrument_inputs_b = []
    nodes_and_names_to_instrument_outputs_a = []
    nodes_and_names_to_instrument_outputs_b = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        # Note: for matching inputs we use start_node, such as observing
        # the input of linear in linear-relu
        if should_log_inputs:
            nodes_and_names_to_instrument_inputs_a.append(
                (subgraph_a.start_node, match_name, ref_node_type_a))
            nodes_and_names_to_instrument_inputs_b.append(
                (subgraph_b.start_node, match_name, ref_node_type_b))
        # Note: for matching activations we always use end_node,
        # such as observing the output of relu in linear-relu
        nodes_and_names_to_instrument_outputs_a.append(
            (subgraph_a.end_node, match_name, ref_node_type_a))
        nodes_and_names_to_instrument_outputs_b.append(
            (subgraph_b.end_node, match_name, ref_node_type_b))

    new_model_a = _add_loggers_one_model(
        name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
        nodes_and_names_to_instrument_outputs_a, logger_cls)
    new_model_b = _add_loggers_one_model(
        name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
        nodes_and_names_to_instrument_outputs_b, logger_cls)
    return (new_model_a, new_model_b)


def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs : bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map)


def _extract_logger_info_one_model(
    model: nn.Module,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
    for gm_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore[arg-type]
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if is_logger:
            key = mod.ref_name
            if key not in results:
                results[key] = {}
            assert mod.model_name not in results[key], \
                f"{mod.model_name} is already present in results"
            if mod.results_type not in results[key]:
                results[key][mod.results_type] = {}
            if mod.model_name not in results[key][mod.results_type]:
                results[key][mod.results_type][mod.model_name] = []
            stats_to_use = mod.stats
            if len(mod.stats_rnn) > 0:
                stats_to_use = mod.stats_rnn
            results[key][mod.results_type][mod.model_name].append({
                'type': mod.results_type,
                'values': stats_to_use,
                'ref_node_name': mod.ref_node_name,
                'ref_node_target_type': mod.ref_node_target_type,
                'prev_node_name': mod.prev_node_name,
                'prev_node_target_type': mod.prev_node_target_type,
                'index_within_arg': mod.index_within_arg,
                'index_of_arg': mod.index_of_arg,
                'fqn': mod.fqn,
            })
            # ensure the list stays sorted
            results[key][mod.results_type][mod.model_name].sort(
                key=lambda res:
                f"{res['index_of_arg']}:{res['index_within_arg']}"
            )


# TODO(future PR): align on naming
# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as ns.extract_logger_info, but for models prepared with
    this module.

    TODO(future PR): real docblock

    Output format: NSResultsType
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
    results: NSResultsType = {}
    for model in (model_a, model_b):
        _extract_logger_info_one_model(model, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return results


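# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. It shows
# the unshadowed logger flow: instrument both models with `add_loggers`, run
# the same data through each, then collect the logged activations with
# `extract_logger_info`. The import path is the post-migration location, and
# the toy model and function name are hypothetical.
def _example_add_loggers_usage():
    import copy
    import torch
    import torch.nn as nn
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx
    import torch.ao.ns._numeric_suite_fx as nsfx

    m = nn.Sequential(nn.Linear(8, 8), nn.ReLU()).eval()
    mp = prepare_fx(copy.deepcopy(m), {"": default_qconfig})
    mp(torch.randn(2, 8))               # calibrate
    mq = convert_fx(copy.deepcopy(mp))

    m_logged, mq_logged = nsfx.add_loggers(
        'fp32', copy.deepcopy(mp), 'int8', copy.deepcopy(mq), nsfx.OutputLogger)
    x = torch.randn(2, 8)
    m_logged(x)
    mq_logged(x)
    results = nsfx.extract_logger_info(m_logged, mq_logged, nsfx.OutputLogger, 'int8')
    print(list(results.keys()))
# ---------------------------------------------------------------------------
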
def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
    gm_a_shadows_b = create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map)
    return gm_a_shadows_b


def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Same thing as add_loggers, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    if hasattr(model_a, '_node_name_to_scope'):
        gm_a._node_name_to_scope = model_a._node_name_to_scope
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    if hasattr(model_b, '_node_name_to_scope'):
        gm_b._node_name_to_scope = model_b._node_name_to_scope
    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map)


def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as extract_logger_info, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
    results: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
    # fill in missing fqn entries
    maybe_add_missing_fqns(results)
    # rekey on the name of model b
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return dict(results)


def extend_logger_results_with_comparison(
    results: NSResultsType,
    model_name_1: str,
    model_name_2: str,
    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    comparison_name: str,
) -> None:
    """
    Compares the logged values from `model_name_2` against the corresponding
    values in `model_name_1`, using `comparison_fn`. Records the result
    in `model_name_2`'s results under `comparison_name`.
    """
    for _, results_type_to_results in results.items():
        for _, model_name_to_results in results_type_to_results.items():
            assert model_name_1 in model_name_to_results, \
                f"{model_name_1} not found in results"
            assert model_name_2 in model_name_to_results, \
                f"{model_name_2} not found in results"

            results_1 = model_name_to_results[model_name_1]
            results_2 = model_name_to_results[model_name_2]

            for result_2 in results_2:
                index_within_arg_2 = result_2['index_within_arg']
                index_of_arg_2 = result_2['index_of_arg']
                # find corresponding result_1
                result_1 = None
                for cur_result_1 in results_1:
                    index_within_arg_1 = cur_result_1['index_within_arg']
                    index_of_arg_1 = cur_result_1['index_of_arg']
                    if (
                        (index_within_arg_1 == index_within_arg_2) and
                        (index_of_arg_1 == index_of_arg_2)
                    ):
                        result_1 = cur_result_1
                        break
                assert result_1 is not None

                values_1 = result_1['values']
                values_2 = result_2['values']
                result_2[comparison_name] = []
                for value_1, value_2 in zip(values_1, values_2):
                    comparison_result = comparison_fn(value_1, value_2)
                    result_2[comparison_name].append(comparison_result)

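
# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original file. It chains
# `add_shadow_loggers`, `extract_shadow_logger_info` and
# `extend_logger_results_with_comparison` with the SQNR comparison used
# elsewhere in this change; import paths reflect the post-migration locations,
# and the toy model and function name are hypothetical.
def _example_shadow_loggers_usage():
    import copy
    import torch
    import torch.nn as nn
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx
    import torch.ao.ns._numeric_suite_fx as nsfx
    from torch.ao.ns.fx.utils import compute_sqnr

    m = nn.Sequential(nn.Linear(8, 8), nn.ReLU()).eval()
    mp = prepare_fx(copy.deepcopy(m), {"": default_qconfig})
    mp(torch.randn(2, 8))               # calibrate
    mq = convert_fx(copy.deepcopy(mp))

    shadow_model = nsfx.add_shadow_loggers(
        'fp32', copy.deepcopy(mp), 'int8', copy.deepcopy(mq), nsfx.OutputLogger)
    shadow_model(torch.randn(2, 8))
    results = nsfx.extract_shadow_logger_info(shadow_model, nsfx.OutputLogger, 'int8')
    nsfx.extend_logger_results_with_comparison(
        results, 'fp32', 'int8', compute_sqnr, 'sqnr')
    print(list(results.keys()))
# ---------------------------------------------------------------------------
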
@ -742,8 +742,8 @@ def get_layer_sqnr_dict(model_a: nn.Module, model_b: nn.Module, x: torch.Tensor)
        model_b: A quantized model
        x: Inputs to use during calibration
    """
    import torch.quantization._numeric_suite_fx as ns
    import torch.ao.ns._numeric_suite_fx as ns
    from torch.quantization.ns.mappings import get_unmatchable_types_map
    from torch.ao.ns.fx.mappings import get_unmatchable_types_map

    unmatchable_types_map = get_unmatchable_types_map()
    unmatchable_types_map["funs_unmatchable"].add(torch.mul)

@ -766,7 +766,7 @@ def get_layer_sqnr_dict(model_a: nn.Module, model_b: nn.Module, x: torch.Tensor)
    ns.extend_logger_results_with_comparison(
        activation_comparison_dict,
        'fp32', 'int8',
        torch.quantization.ns.utils.compute_sqnr, 'sqnr'
        torch.ao.ns.fx.utils.compute_sqnr, 'sqnr'
    )

    # Construct a dictionary mapping layer names to the SQNR values

@ -33,7 +33,7 @@ try:
        prepare_qat_fx,
        convert_fx,
    )
    from torch.quantization.ns.ns_types import NSSingleResultValuesType, NSSubgraph
    from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph
    from torch.fx.graph import Node
    from torch.fx import GraphModule
    HAS_FX = True