Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55505

This is necessary to add support in NS for QAT modules, to avoid duplicating
logic between NSTracer and QuantizationTracer. The eng work to expose the
custom module and class names to the user will be in a future PR.

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
python test/test_quantization.py TestFXNumericSuiteCoreAPIsModels
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D27650407

fbshipit-source-id: 431f47c5353b41c11371c5efa79657bfd085459a
import collections

import torch
import torch.nn as nn
import torch.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.quantization.ns.graph_matcher import (
    get_matching_subgraph_pairs,
    get_base_name_to_sets_of_related_ops,
    get_type_a_related_to_b,
)

from .ns.weight_utils import (
    extract_weight_from_node,
)

from .ns.graph_passes import (
    remove_observers_add_loggers,
    create_a_shadows_b,
)

from .ns.ns_types import (
    NSSingleResultValuesType,
    NSResultsType,
)

from typing import Dict, Tuple, Callable, List

RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

class OutputLogger(nn.Module):
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        results_type: str,
        index_within_arg: int,
    ):
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        # name of the node which was responsible for adding this logger
        # Note:
        # - if we are logging node outputs, this is the same as prev_node_name
        # - if we are logging node inputs, this is the name of the node
        #   whose input this logger is logging.
        #
        # example, where logger1 is logging the input of op1 and logger2 is
        # logging the output of op1:
        #
        #   x1 -> logger1 -> op1 -> logger2 -> x2
        #
        # in this example,
        # - logger1's prev_node_name is x1 and ref_node_name is op1
        # - logger2's prev_node_name is op1 and ref_node_name is op1
        self.ref_node_name = ref_node_name
        # name of the node whose output this Logger is capturing
        self.prev_node_name = prev_node_name

        # name of the model from which the node originated
        self.model_name = model_name
        # reference name, used to match loggers from separate models
        # to each other
        self.ref_name = ref_name
        # type of the target of the node whose output this logger is logging
        self.prev_node_target_type = prev_node_target_type
        # what kind of values are inside of stats
        self.results_type = results_type
        # index of this node within the arg of the input/output node
        # for example, in cat([x1, x2, x3], dim=0), x2 would have
        # index_within_arg == 1
        self.index_within_arg = index_within_arg

    # Note: cannot annotate the type of x because TorchScript does not support
    # the Union type.
    def forward(self, x):
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        return x

    def __repr__(self):
        return f"""OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},
prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},
results_type={self.results_type}, index_within_arg={self.index_within_arg})"""
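
# A minimal sketch of how an OutputLogger behaves once inserted into a graph
# (illustrative only; the constructor argument values below are hypothetical,
# and NSSingleResultValuesType.NODE_OUTPUT is assumed to be the results_type
# used for output logging):
#
#   logger = OutputLogger(
#       ref_node_name='linear1', prev_node_name='linear1',
#       model_name='model_a', ref_name='linear1',
#       prev_node_target_type='torch.nn.Linear',
#       results_type=NSSingleResultValuesType.NODE_OUTPUT.value,
#       index_within_arg=0)
#   y = logger(torch.randn(2, 4))  # forward passes x through unchanged
#   # logger.stats now holds one detached copy of the tensor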


class NSTracer(quantize_fx.QuantizationTracer):
    """
    Just like a regular tracer, but treats observers and fake_quantize
    modules as leaf modules.
    """
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        if isinstance(m, torch.quantization.ObserverBase):
            return True
        elif isinstance(m, torch.quantization.FakeQuantizeBase):
            return True
        return super().is_leaf_module(m, module_qualified_name)
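
# For example, mirroring how the functions below build graph modules, a model
# can be traced with observers and fake_quants kept as leaves via:
#
#   tracer = NSTracer([], [])  # skipped_module_names, skipped_module_classes
#   gm = GraphModule(model, tracer.trace(model))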


def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
) -> None:
    base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    for node, ref_name in nodes_and_names_to_instrument:
        res_type = NSSingleResultValuesType.WEIGHT.value
        if ref_name not in results:
            results[ref_name] = {res_type: {}}
        extracted_weight = \
            extract_weight_from_node(node, model, type_a_related_to_b)
        if extracted_weight:
            results[ref_name][res_type][model_name] = [extracted_weight]


def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
) -> NSResultsType:
    matched_subgraph_pairs = get_matching_subgraph_pairs(gm_a, gm_b)

    # split the subgraph pairs into one data structure for each model
    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
        nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))

    # populate the results, one model at a time
    results: NSResultsType = {}
    _extract_weights_one_model(
        model_name_a, gm_a, nodes_and_names_to_instrument_a, results)
    _extract_weights_one_model(
        model_name_b, gm_b, nodes_and_names_to_instrument_b, results)

    return results


def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
) -> NSResultsType:
    base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    return _extract_weights_impl(model_name_a, gm_a, model_name_b, gm_b)
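
# A minimal usage sketch (the model names and model variables below are
# hypothetical): comparing the weights of a float model against another
# version of it, e.g. a quantized one:
#
#   results = extract_weights('fp32', model_fp32, 'int8', model_int8)
#
# For each matched subgraph pair, the extracted weight for each model is
# stored in results[ref_name][NSSingleResultValuesType.WEIGHT.value][model_name].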


def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str]],
    logger_cls: Callable,
) -> nn.Module:

    # TODO(future PR): do not observe nodes we do not care
    # about (both fp32, denylist, etc)
    node_to_instrument_inputs_to_ref_name: Dict[Node, str] = {}
    node_to_instrument_outputs_to_ref_name: Dict[Node, str] = {}
    for node, ref_name in nodes_and_names_to_instrument_inputs:
        node_to_instrument_inputs_to_ref_name[node] = ref_name
    for node, ref_name in nodes_and_names_to_instrument_outputs:
        node_to_instrument_outputs_to_ref_name[node] = ref_name

    model = remove_observers_add_loggers(
        model, node_to_instrument_inputs_to_ref_name,
        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
    return model


def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
) -> Tuple[nn.Module, nn.Module]:
    matched_subgraph_pairs = get_matching_subgraph_pairs(gm_a, gm_b)
    nodes_and_names_to_instrument_inputs_a = []
    nodes_and_names_to_instrument_inputs_b = []
    nodes_and_names_to_instrument_outputs_a = []
    nodes_and_names_to_instrument_outputs_b = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        # Note: for matching inputs we use start_node, such as observing
        # the input of linear in linear-relu
        if should_log_inputs:
            nodes_and_names_to_instrument_inputs_a.append((subgraph_a.start_node, match_name))
            nodes_and_names_to_instrument_inputs_b.append((subgraph_b.start_node, match_name))
        # Note: for matching activations we always use end_node,
        # such as observing the output of relu in linear-relu
        nodes_and_names_to_instrument_outputs_a.append((subgraph_a.end_node, match_name))
        nodes_and_names_to_instrument_outputs_b.append((subgraph_b.end_node, match_name))

    new_model_a = _add_loggers_one_model(
        name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
        nodes_and_names_to_instrument_outputs_a, logger_cls)
    new_model_b = _add_loggers_one_model(
        name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
        nodes_and_names_to_instrument_outputs_b, logger_cls)
    return (new_model_a, new_model_b)


def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
) -> Tuple[nn.Module, nn.Module]:
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs)
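
# A sketch of the intended flow (model and data names are hypothetical): add
# loggers to both models, run the same data through each so the loggers
# record comparable values, then collect the results:
#
#   m_a, m_b = add_loggers('fp32', model_fp32, 'int8', model_int8, OutputLogger)
#   for data in calibration_data:
#       m_a(data)
#       m_b(data)
#   results = extract_logger_info(m_a, m_b, OutputLogger)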


def _extract_logger_info_one_model(
    model: nn.Module,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    for gm_name, mod in model.named_modules():
        # TODO(future PR): better check when scripted
        is_logger = (
            isinstance(mod, logger_cls)  # type: ignore
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if is_logger:
            key = mod.ref_name
            if key not in results:
                results[key] = {}
            assert mod.model_name not in results[key], \
                f"{mod.model_name} is already present in results"
            if mod.results_type not in results[key]:
                results[key][mod.results_type] = {}
            if mod.model_name not in results[key][mod.results_type]:
                results[key][mod.results_type][mod.model_name] = []
            stats_to_use = mod.stats
            if len(mod.stats_rnn) > 0:
                stats_to_use = mod.stats_rnn
            results[key][mod.results_type][mod.model_name].append({
                'type': mod.results_type,
                'values': stats_to_use,
                'ref_node_name': mod.ref_node_name,
                'prev_node_name': mod.prev_node_name,
                'prev_node_target_type': mod.prev_node_target_type,
                'index_within_arg': mod.index_within_arg,
            })
            # ensure the list stays sorted
            results[key][mod.results_type][mod.model_name].sort(
                key=lambda res: res['index_within_arg']
            )


# TODO(future PR): align on naming
# this is equivalent to just the comparison extraction part of
# `ns.compare_model_outputs`
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
) -> NSResultsType:
    """
    Same thing as ns.extract_logger_info, but for models prepared with
    this module.

    TODO(future PR): real docblock

    Output format: NSResultsType
    """
    results: NSResultsType = {}
    for model in (model_a, model_b):
        _extract_logger_info_one_model(model, results, logger_cls)
    return results
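
# Shape of the returned NSResultsType, following the dict construction in
# _extract_logger_info_one_model above:
#
#   {
#       ref_name: {
#           results_type: {
#               model_name: [
#                   {
#                       'type': ...,
#                       'values': ...,
#                       'ref_node_name': ...,
#                       'prev_node_name': ...,
#                       'prev_node_target_type': ...,
#                       'index_within_arg': ...,
#                   },
#                   ...
#               ],
#           },
#       },
#   }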


def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
) -> nn.Module:
    matched_subgraph_pairs = get_matching_subgraph_pairs(gm_a, gm_b)
    gm_a_shadows_b = create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
        should_log_inputs=should_log_inputs)
    return gm_a_shadows_b


def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
) -> nn.Module:
    """
    Same thing as add_loggers, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    # TODO(future PR): expose these
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []
    tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
    tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
    gm_a = GraphModule(model_a, tracer_a.trace(model_a))
    gm_b = GraphModule(model_b, tracer_b.trace(model_b))
    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs)


def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
) -> NSResultsType:
    """
    Same thing as extract_logger_info, but for an `a_shadows_b` model.
    TODO(future PR): real docblock
    """
    results: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
    return dict(results)
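
# A sketch of the end-to-end shadow flow (model and data names are
# hypothetical): the two models are combined into a single `a_shadows_b`
# module, so only that module needs to be run on the calibration data:
#
#   m_shadow = add_shadow_loggers(
#       'fp32', model_fp32, 'int8', model_int8, OutputLogger)
#   for data in calibration_data:
#       m_shadow(data)
#   results = extract_shadow_logger_info(m_shadow, OutputLogger)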