pytorch/torch/quantization/_numeric_suite_fx.py
Vasiliy Kuznetsov 31fe1c1323 ns for fx: rekey results by model node names (#60305)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60305

Adjusts the NS for FX weight and activation extraction APIs
to require a model name, and rekeys the results of these APIs
to use the node names of the specified model as layer keys.

For example, before

```
// API call
results = ns.extract_logger_info(
  model_a, model_b, ns.OutputLogger)

// results
{'base_op_1_0': {'node_output':
  {'model_a': [{'ref_node_name': 'linear1', ...}]}}}
```

and after

```
// API call
results = ns.extract_logger_info(
  model_a, model_b, ns.OutputLogger, 'model_b_name')

// results
// note: instead of `base_op_1_0`, the layer is named `linear1`
{'linear1': {'node_output':
  {'model_a': [{'ref_node_name': 'linear1', ...}]}}}
```

Note: we cannot use these names while collecting data because
node names are not guaranteed to be consistent across graphs.
This is why we only rekey as the very last step.

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_layer_names
```

Imported from OSS

Reviewed By: hx89

Differential Revision: D29243045

fbshipit-source-id: d39ecdfdd18b07291e3ecefed2ede287b100b7d0
2021-06-24 13:41:01 -07:00

428 lines
17 KiB
Python

import collections
import torch
import torch.nn as nn
import torch.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.quantization.ns.mappings import (
get_base_name_to_sets_of_related_ops,
)
from torch.quantization.ns.graph_matcher import (
get_matching_subgraph_pairs,
get_type_a_related_to_b,
)
from .ns.weight_utils import (
extract_weight_from_node,
)
from .ns.graph_passes import (
remove_observers_add_loggers,
create_a_shadows_b,
)
from .ns.utils import (
rekey_logger_info_on_node_name_of_model,
)
from .ns.ns_types import (
NSSingleResultValuesType,
NSResultsType,
NSNodeTargetType,
)
from typing import Dict, Tuple, Callable, List, Optional, Set
# A tensor paired with a 2-tuple of tensors; values of this shape are stored
# in `OutputLogger.stats_rnn`.
RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


class OutputLogger(nn.Module):
    """
    Pass-through module which records the values it sees.

    Tensors are detached and appended to `stats`. Tuples of the form
    `(t0, (t1, t2))` are detached element by element and appended to
    `stats_rnn`. Anything else is returned without being recorded.

    Example of how the node names relate, where logger1 logs the input of
    op1 and logger2 logs the output of op1:

      x1 -> logger1 -> op1 -> logger2 -> x2

      - logger1: prev_node_name is x1,  ref_node_name is op1
      - logger2: prev_node_name is op1, ref_node_name is op1

    Args:
      ref_node_name: name of the node responsible for adding this logger.
        When logging node outputs this equals `prev_node_name`; when logging
        node inputs it is the node whose input is being logged.
      prev_node_name: name of the node whose output this logger captures.
      model_name: name of the model the node came from.
      ref_name: reference name, used to match loggers from separate models
        to each other.
      prev_node_target_type: type of the target of the node whose output
        this logger is logging.
      results_type: what kind of values are stored in `stats`.
      index_within_arg: index of this node within the arg of the
        input/output node, e.g. in cat([x1, x2, x3], dim=0), x2 has
        index_within_arg == 1.
      index_of_arg: index of this node within the args of the input/output
        node, e.g. in add(x1, x2), x2 has index_of_arg == 1.
    """
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
    ):
        super().__init__()
        # storage for the logged values
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []
        # where this logger sits in the graph
        self.ref_node_name = ref_node_name
        self.prev_node_name = prev_node_name
        self.model_name = model_name
        self.ref_name = ref_name
        # description of what is being logged
        self.prev_node_target_type = prev_node_target_type
        self.results_type = results_type
        # position of the logged value among the args of the related node
        self.index_within_arg = index_within_arg
        self.index_of_arg = index_of_arg

    # Note: the type of x is left unannotated because TorchScript does not
    # support the Union type.
    def forward(self, x):
        """Records a detached copy of `x` if its form is recognized; returns `x` as is."""
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            detached = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(detached)
        return x

    def __repr__(self):
        return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
results_type={self.results_type}, index_within_arg={self.index_within_arg},
index_of_arg={self.index_of_arg})"""
class NSTracer(quantize_fx.QuantizationTracer):
    """
    Quantization tracer which additionally treats observer and
    fake_quantize modules as leaf modules.
    """
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
        """
        Returns True for observers and fake_quantize modules; for every
        other module, defers to the parent tracer's decision.
        """
        is_observer_or_fake_quant = isinstance(
            m, (torch.quantization.ObserverBase, torch.quantization.FakeQuantizeBase))
        return is_observer_or_fake_quant or super().is_leaf_module(m, module_qualified_name)
def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
) -> None:
    """
    Extracts weights for the given nodes of one model and stores them in
    `results`, in place.

    For each `(node, ref_name)` pair, the value returned by
    `extract_weight_from_node` is stored (wrapped in a single-element list)
    at `results[ref_name][<weight results type>][model_name]`. Nodes for
    which the extraction returns a falsy value are skipped.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
    related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = get_type_a_related_to_b(related_ops)
    weight_key = NSSingleResultValuesType.WEIGHT.value

    for node, ref_name in nodes_and_names_to_instrument:
        weight_info = extract_weight_from_node(node, model, type_a_related_to_b)
        if not weight_info:
            continue
        if ref_name not in results:
            results[ref_name] = {weight_key: {}}
        results[ref_name][weight_key][model_name] = [weight_info]
def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> NSResultsType:
    """
    Matches subgraphs between `gm_a` and `gm_b`, extracts the weights of the
    base op node of every matched subgraph from both models, and returns the
    results rekeyed on the node names of model b.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
    subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)

    # one (base op node, match name) list per model
    to_instrument_a: List[Tuple[Node, str]] = [
        (subgraph_a.base_op_node, match_name)
        for match_name, (subgraph_a, _subgraph_b) in subgraph_pairs.items()
    ]
    to_instrument_b: List[Tuple[Node, str]] = [
        (subgraph_b.base_op_node, match_name)
        for match_name, (_subgraph_a, subgraph_b) in subgraph_pairs.items()
    ]

    # fill in the results, one model at a time
    results: NSResultsType = {}
    _extract_weights_one_model(model_name_a, gm_a, to_instrument_a, results)
    _extract_weights_one_model(model_name_b, gm_b, to_instrument_b, results)

    # rekey on names of nodes in gm_b
    return rekey_logger_info_on_node_name_of_model(results, model_name_b)
def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> NSResultsType:
    """
    Extracts weights from matching subgraphs of `model_a` and `model_b`.

    Both models are symbolically traced with `NSTracer`, wrapped into
    `GraphModule`s, and handed to `_extract_weights_impl`.

    Args:
      model_name_a: name under which results of model A are stored
      model_a: model A
      model_name_b: name under which results of model B are stored
      model_b: model B
      base_name_to_sets_of_related_ops: optional override of the mapping
        used to decide which ops are related; if None, the default from
        `get_base_name_to_sets_of_related_ops()` is used
      unmatchable_types_map: optional override, forwarded to the subgraph
        matcher

    Returns:
      NSResultsType, as produced by `_extract_weights_impl`.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    # Fall back to the default mapping only when the caller did not supply
    # one, so that a caller-provided mapping reaches the graph matcher.
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    return _extract_weights_impl(
        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str]],
    logger_cls: Callable,
) -> nn.Module:
    """
    Instruments one model with loggers.

    The two `(node, ref_name)` lists are converted into node -> ref_name
    mappings (a later entry for the same node replaces an earlier one) and
    passed to `remove_observers_add_loggers`, whose result is returned.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")

    # TODO(future PR): do not observe nodes we do not care
    # about (both fp32, denylist, etc)
    input_node_to_ref_name: Dict[Node, str] = dict(nodes_and_names_to_instrument_inputs)
    output_node_to_ref_name: Dict[Node, str] = dict(nodes_and_names_to_instrument_outputs)

    return remove_observers_add_loggers(
        model, input_node_to_ref_name, output_node_to_ref_name,
        logger_cls, model_name)
def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Matches subgraphs between `gm_a` and `gm_b` and adds loggers to both.

    Outputs are always logged at the end node of each matched subgraph
    (for example the output of relu in linear-relu). If `should_log_inputs`
    is set, inputs are additionally logged at the start node (for example
    the input of linear in linear-relu).

    Returns the instrumented versions of model a and model b, in that order.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b,
        base_name_to_sets_of_related_ops, unmatchable_types_map)

    inputs_a = []
    inputs_b = []
    if should_log_inputs:
        # for matching inputs we use start_node
        inputs_a = [
            (sub_a.start_node, match_name)
            for match_name, (sub_a, _sub_b) in subgraph_pairs.items()
        ]
        inputs_b = [
            (sub_b.start_node, match_name)
            for match_name, (_sub_a, sub_b) in subgraph_pairs.items()
        ]

    # for matching activations we always use end_node
    outputs_a = [
        (sub_a.end_node, match_name)
        for match_name, (sub_a, _sub_b) in subgraph_pairs.items()
    ]
    outputs_b = [
        (sub_b.end_node, match_name)
        for match_name, (_sub_a, sub_b) in subgraph_pairs.items()
    ]

    instrumented_a = _add_loggers_one_model(
        name_a, gm_a, inputs_a, outputs_a, logger_cls)
    instrumented_b = _add_loggers_one_model(
        name_b, gm_b, inputs_b, outputs_b, logger_cls)
    return (instrumented_a, instrumented_b)
def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs : bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Traces `model_a` and `model_b` with `NSTracer` and instruments the
    resulting graph modules with loggers via `_add_loggers_impl`.

    Returns the instrumented versions of model a and model b, in that order.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []

    tracer_for_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_for_b = NSTracer(skipped_module_names, skipped_module_classes)
    graph_a = tracer_for_a.trace(model_a)
    gm_a = GraphModule(model_a, graph_a)
    graph_b = tracer_for_b.trace(model_b)
    gm_b = GraphModule(model_b, graph_b)

    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs,
        base_name_to_sets_of_related_ops,
        unmatchable_types_map)
def _extract_logger_info_one_model(
model: nn.Module,
results: NSResultsType,
logger_cls: Callable,
) -> None:
torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
for gm_name, mod in model.named_modules():
# TODO(future PR): better check when scripted
is_logger = (
isinstance(mod, logger_cls) # type: ignore[arg-type]
or (
isinstance(mod, torch.jit.RecursiveScriptModule)
and mod.original_name == 'OutputLogger'
)
)
if is_logger:
key = mod.ref_name
if key not in results:
results[key] = {}
assert mod.model_name not in results[key], \
f"{mod.model_name} is already present in results"
if mod.results_type not in results[key]:
results[key][mod.results_type] = {}
if mod.model_name not in results[key][mod.results_type]:
results[key][mod.results_type][mod.model_name] = []
stats_to_use = mod.stats
if len(mod.stats_rnn) > 0:
stats_to_use = mod.stats_rnn
results[key][mod.results_type][mod.model_name].append({
'type': mod.results_type,
'values': stats_to_use,
'ref_node_name': mod.ref_node_name,
'prev_node_name': mod.prev_node_name,
'prev_node_target_type': mod.prev_node_target_type,
'index_within_arg': mod.index_within_arg,
'index_of_arg': mod.index_of_arg,
})
# ensure the list stays sorted
results[key][mod.results_type][mod.model_name].sort(
key=lambda res:
f"{res['index_of_arg']}:{res['index_within_arg']}"
)
# TODO(future PR): align on naming
# this is the equivalent of just the comparison extraction part of `ns.compare_model_outputs`
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as ns.extract_logger_info, but for models prepared with
    this module.

    Logger data is gathered from `model_a` first and `model_b` second, and
    the combined results are rekeyed on the node names of the model called
    `model_name_to_use_for_layer_names`.

    TODO(future PR): real docblock

    Output format: NSResultsType
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
    collected: NSResultsType = {}
    _extract_logger_info_one_model(model_a, collected, logger_cls)
    _extract_logger_info_one_model(model_b, collected, logger_cls)
    # rekey on the node names of the requested model
    return rekey_logger_info_on_node_name_of_model(
        collected, model_name_to_use_for_layer_names)
def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Matches subgraphs between `gm_a` and `gm_b` and returns the module
    built from those matches by `create_a_shadows_b`.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
    subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
    return create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, subgraph_pairs, logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map)
def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Same thing as add_loggers, but for an `a_shadows_b` model.

    Both models are traced with `NSTracer`, and the resulting graph modules
    are passed to `_add_shadow_loggers_impl`.

    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []

    tracer_for_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_for_b = NSTracer(skipped_module_names, skipped_module_classes)
    graph_a = tracer_for_a.trace(model_a)
    gm_a = GraphModule(model_a, graph_a)
    graph_b = tracer_for_b.trace(model_b)
    gm_b = GraphModule(model_b, graph_b)

    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs,
        base_name_to_sets_of_related_ops,
        node_type_to_io_type_map,
        unmatchable_types_map)
def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Same thing as extract_logger_info, but for an `a_shadows_b` model.

    Logger data is gathered from the single combined model, rekeyed on the
    node names of the model called `model_name_to_use_for_layer_names`, and
    returned as a plain dict.

    TODO(future PR): real docblock
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
    collected: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, collected, logger_cls)
    # rekey on the node names of the requested model
    rekeyed = rekey_logger_info_on_node_name_of_model(
        collected, model_name_to_use_for_layer_names)
    return dict(rekeyed)