Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60305

Adjusts the NS for FX weight and activation extraction APIs to require a model name, and rekeys the results of these APIs to use the node names of the specified model as layer keys.

For example, before:

```
# API call
results = ns.extract_logger_info(model_a, model_b, ns.OutputLogger)

# results
{'base_op_1_0': {'node_output': {'model_a': [{'ref_node_name': 'linear1', ...}]}}}
```

and after:

```
# API call
results = ns.extract_logger_info(model_a, model_b, ns.OutputLogger, 'model_b_name')

# results
# note: instead of `base_op_1_0`, the layer is named `linear1`
{'linear1': {'node_output': {'model_a': [{'ref_node_name': 'linear1', ...}]}}}
```

Note: we cannot use these names while collecting data, because node names are not guaranteed to be consistent across graphs. This is why we only rekey as the very last step.

Test Plan:

```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_layer_names
```

Imported from OSS

Reviewed By: hx89

Differential Revision: D29243045

fbshipit-source-id: d39ecdfdd18b07291e3ecefed2ede287b100b7d0
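For reference, here is a minimal end-to-end sketch of the rekeyed activation API. The toy model `M`, the calibration input, and the model names `'fp32_a'`/`'fp32_b'` are all illustrative, not part of this change:

```
import copy
import torch
import torch.nn as nn
import torch.quantization._numeric_suite_fx as ns

# hypothetical toy model, for illustration only
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear1(x)

m_a = M().eval()
m_b = copy.deepcopy(m_a)
# instrument both models with output loggers
m_a, m_b = ns.add_loggers('fp32_a', m_a, 'fp32_b', m_b, ns.OutputLogger)
# run calibration data through both models to populate the loggers
datum = torch.randn(1, 4)
m_a(datum)
m_b(datum)
# extract results, keyed by the node names of the model named 'fp32_b'
results = ns.extract_logger_info(m_a, m_b, ns.OutputLogger, 'fp32_b')
```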
import collections

import torch
import torch.nn as nn
import torch.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.quantization.ns.mappings import (
    get_base_name_to_sets_of_related_ops,
)
from torch.quantization.ns.graph_matcher import (
    get_matching_subgraph_pairs,
    get_type_a_related_to_b,
)

from .ns.weight_utils import (
    extract_weight_from_node,
)

from .ns.graph_passes import (
    remove_observers_add_loggers,
    create_a_shadows_b,
)

from .ns.utils import (
    rekey_logger_info_on_node_name_of_model,
)

from .ns.ns_types import (
    NSSingleResultValuesType,
    NSResultsType,
    NSNodeTargetType,
)

from typing import Dict, Tuple, Callable, List, Optional, Set

RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

class OutputLogger(nn.Module):
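    """
    Captures the intermediate values that flow through the node this logger
    is attached to: plain `torch.Tensor` values are appended to `stats`, and
    RNN-style `(output, (hidden, cell))` tuples are appended to `stats_rnn`.
    """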
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # name of the node which was responsible for adding this logger
        # Note:
        # - if we are logging node outputs, this is the same as prev_node_name
        # - if we are logging node inputs, this is the name of the node
        #   whose input this logger is logging
        #
        # for example, where logger1 is logging the input of op1 and logger2
        # is logging the output of op1:
        #
        #   x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # in this example,
        # - logger1's prev_node_name is x1 and ref_node_name is op1
        # - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node
        # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
        self.index_within_arg = index_within_arg
        # index of this node within the args of the input/output node
        # for example, in add(x1, x2), x2 would have index_of_arg == 1
        self.index_of_arg = index_of_arg

    # Note: cannot annotate the type of x because TorchScript does not support
    # the Union type.
    def forward(self, x):
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            # this matches RNNReturnType, i.e. an RNN-style output of
            # (output, (hidden_state, cell_state))
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        return x

    def __repr__(self):
        return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
results_type={self.results_type}, index_within_arg={self.index_within_arg},
index_of_arg={self.index_of_arg})"""


class NSTracer(quantize_fx.QuantizationTracer):
    """
    Just like a regular tracer, but treats observers and fake_quantize
    modules as leaf modules.
    """
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        if isinstance(m, torch.quantization.ObserverBase):
            return True
        elif isinstance(m, torch.quantization.FakeQuantizeBase):
            return True
        return super().is_leaf_module(m, module_qualified_name)


def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
) -> None:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
    base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    for node, ref_name in nodes_and_names_to_instrument:
        res_type = NSSingleResultValuesType.WEIGHT.value
        extracted_weight = \
            extract_weight_from_node(node, model, type_a_related_to_b)
        if extracted_weight:
            if ref_name not in results:
                results[ref_name] = {res_type: {}}
            results[ref_name][res_type][model_name] = [extracted_weight]


def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> NSResultsType:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)

    # split the subgraph pairs into one data structure for each model
    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
        nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))

    # populate the results, one model at a time
    results: NSResultsType = {}
    _extract_weights_one_model(
        model_name_a, gm_a, nodes_and_names_to_instrument_a, results)
    _extract_weights_one_model(
        model_name_b, gm_b, nodes_and_names_to_instrument_b, results)

    # rekey on names of nodes in gm_b
    results = rekey_logger_info_on_node_name_of_model(results, model_name_b)

    return results


def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> NSResultsType:
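    """
    Extracts and matches weights from model_a and model_b, and returns the
    results keyed by the node names of model_b.

    Example usage (a sketch; ``model_fp32`` and ``model_int8`` are
    hypothetical user-provided models with matchable subgraphs)::

        results = extract_weights('fp32', model_fp32, 'int8', model_int8)
        # results[layer_name][results_type][model_name] -> list of entries

    Output format: NSResultsType
    """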
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    # only fall back to the default mapping if the caller did not provide one
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = \
            get_base_name_to_sets_of_related_ops()

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    return _extract_weights_impl(
        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)


def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str]],
    logger_cls: Callable,
) -> nn.Module:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")

    # TODO(future PR): do not observe nodes we do not care
    # about (both fp32, denylist, etc)
    node_to_instrument_inputs_to_ref_name: Dict[Node, str] = {}
    node_to_instrument_outputs_to_ref_name: Dict[Node, str] = {}
    for node, ref_name in nodes_and_names_to_instrument_inputs:
        node_to_instrument_inputs_to_ref_name[node] = ref_name
    for node, ref_name in nodes_and_names_to_instrument_outputs:
        node_to_instrument_outputs_to_ref_name[node] = ref_name

    model = remove_observers_add_loggers(
        model, node_to_instrument_inputs_to_ref_name,
        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
    return model


def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b,
        base_name_to_sets_of_related_ops, unmatchable_types_map)
    nodes_and_names_to_instrument_inputs_a = []
    nodes_and_names_to_instrument_inputs_b = []
    nodes_and_names_to_instrument_outputs_a = []
    nodes_and_names_to_instrument_outputs_b = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        # Note: for matching inputs we use start_node, such as observing
        # the input of linear in linear-relu
        if should_log_inputs:
            nodes_and_names_to_instrument_inputs_a.append((subgraph_a.start_node, match_name))
            nodes_and_names_to_instrument_inputs_b.append((subgraph_b.start_node, match_name))
        # Note: for matching activations we always use end_node,
        # such as observing the output of relu in linear-relu
        nodes_and_names_to_instrument_outputs_a.append((subgraph_a.end_node, match_name))
        nodes_and_names_to_instrument_outputs_b.append((subgraph_b.end_node, match_name))

    new_model_a = _add_loggers_one_model(
        name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
        nodes_and_names_to_instrument_outputs_a, logger_cls)
    new_model_b = _add_loggers_one_model(
        name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
        nodes_and_names_to_instrument_outputs_b, logger_cls)
    return (new_model_a, new_model_b)


def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
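    """
    Instruments model_a and model_b with loggers of type logger_cls at the
    inputs and/or outputs of each matched subgraph, and returns the pair of
    instrumented models.

    Example usage (a sketch; the models and names are illustrative)::

        m_a, m_b = add_loggers('fp32', model_fp32, 'int8', model_int8, OutputLogger)
        # run calibration data through m_a and m_b, then call
        # extract_logger_info to collect the logged values
    """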
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map)


def _extract_logger_info_one_model(
    model: nn.Module,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
    for gm_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore[arg-type]
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if is_logger:
            key = mod.ref_name
            if key not in results:
                results[key] = {}
            assert mod.model_name not in results[key], \
                f"{mod.model_name} is already present in results"
            if mod.results_type not in results[key]:
                results[key][mod.results_type] = {}
            if mod.model_name not in results[key][mod.results_type]:
                results[key][mod.results_type][mod.model_name] = []
            stats_to_use = mod.stats
            if len(mod.stats_rnn) > 0:
                stats_to_use = mod.stats_rnn
            results[key][mod.results_type][mod.model_name].append({
                'type': mod.results_type,
                'values': stats_to_use,
                'ref_node_name': mod.ref_node_name,
                'prev_node_name': mod.prev_node_name,
                'prev_node_target_type': mod.prev_node_target_type,
                'index_within_arg': mod.index_within_arg,
                'index_of_arg': mod.index_of_arg,
            })
            # ensure the list stays sorted
            # Note: sorting on a numeric tuple instead of the string
            # f"{index_of_arg}:{index_within_arg}", because string comparison
            # would order "10" before "2"
            results[key][mod.results_type][mod.model_name].sort(
                key=lambda res:
                (res['index_of_arg'], res['index_within_arg'])
            )


# TODO(future PR): align on naming
# this is equivalent to just the comparison extraction part of `ns.compare_model_outputs`
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as ns.extract_logger_info, but for models prepared with
    this module.

    TODO(future PR): real docblock

    Output format: NSResultsType
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
    results: NSResultsType = {}
    for model in (model_a, model_b):
        _extract_logger_info_one_model(model, results, logger_cls)
    # rekey on the layer names of the model specified by the caller
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return results


def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
    gm_a_shadows_b = create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map)
    return gm_a_shadows_b


def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Same thing as add_loggers, but for an `a_shadows_b` model: returns a
    single model in which model_a shadows model_b.

    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map)


def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as extract_logger_info, but for an `a_shadows_b` model.

    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
    results: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
    # rekey on the layer names of the model specified by the caller
    results = rekey_logger_info_on_node_name_of_model(
        results, model_name_to_use_for_layer_names)
    return dict(results)