[torch][ao] Add customizable loss function to NodeAccuracySummary (#136282)

Summary:
Add a customizable loss function callback to NodeAccuracySummary to
allow users to pass in their own loss function.
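
As a rough usage sketch of the new hook (the import path and the sample tensors below are assumptions; only the `QuantizationComparisonResult(actual=..., ref=...)` constructor, the existing `sqnr` metric, and the new `loss(...)` method come from this diff):

```python
import torch
import torch.nn.functional as F

# Assumed import path for the numeric debugger module edited in this diff.
from torch.ao.quantization.pt2e._numeric_debugger import QuantizationComparisonResult

ref = torch.randn(4, 8)
actual = ref + 0.01 * torch.randn(4, 8)  # stand-in for a quantized node's output

result = QuantizationComparisonResult(actual=actual, ref=ref)

print(result.sqnr)                                    # existing built-in metric
print(result.loss(F.mse_loss))                        # new: any (actual, ref) -> Tensor callable
print(result.loss(lambda a, b: (a - b).abs().max()))  # e.g. max absolute error
```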

Also, fix some type errors and propagate better exception messages when
unexpected tensor comparisons occur. Finally, enhance the robustness of
`generate_numeric_debug_handle` in the case where it is called multiple
times on the same model, by avoiding reuse of the same IDs.
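
A minimal sketch of what the re-generation guarantee means in practice (the toy module and import paths are assumptions; the behavior — existing handles are kept and no ID is reused across calls — is what this change adds):

```python
from collections import Counter

import torch
from torch._export import capture_pre_autograd_graph  # same capture API the new test uses
from torch.ao.quantization.pt2e._numeric_debugger import (  # assumed import path
    CUSTOM_KEY,
    NUMERIC_DEBUG_HANDLE_KEY,
    generate_numeric_debug_handle,
)


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)


m = capture_pre_autograd_graph(Toy(), (torch.randn(2, 3),))
generate_numeric_debug_handle(m)  # first pass: every non-placeholder node gets an id
generate_numeric_debug_handle(m)  # second pass (e.g. after graph edits): existing ids kept, none reused

handles = [
    node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
    for node in m.graph.nodes
    if NUMERIC_DEBUG_HANDLE_KEY in node.meta.get(CUSTOM_KEY, {})
]
assert all(count == 1 for count in Counter(handles).values())  # every handle is unique
```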

Test Plan: Added a test for this case in `test_numeric_debugger`.

Reviewed By: jerryzh168

Differential Revision: D62898297

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136282
Approved by: https://github.com/jerryzh168
commit 3be150653c (parent e09c5b6046)
Author: Riley Dulin
Date: 2024-09-24 03:28:09 +00:00
Committed by: PyTorch MergeBot

2 changed files with 97 additions and 14 deletions


@@ -25,8 +25,8 @@ from torch.testing._internal.common_quantization import TestHelperModules
 from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase
 
 
-def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]:
-    debug_handle_map: Dict[torch.fx.Node, int] = {}
+def _extract_debug_handles(model) -> Dict[str, int]:
+    debug_handle_map: Dict[str, int] = {}
 
     for node in model.graph.nodes:
         if (
@@ -187,3 +187,53 @@ class TestNumericDebugger(TestCase):
         for node_summary in comparison_results.values():
             if len(node_summary.results) > 0:
                 self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
+
+    def test_added_node_gets_unique_id(self) -> None:
+        m = TestHelperModules.Conv2dThenConv1d()
+        example_inputs = m.example_inputs()
+        m = capture_pre_autograd_graph(m, example_inputs)
+        assert isinstance(m, torch.fx.GraphModule)
+        generate_numeric_debug_handle(m)
+        ref_handles = _extract_debug_handles(m)
+        ref_counter = Counter(ref_handles.values())
+        for k, v in ref_counter.items():
+            self.assertEqual(
+                v,
+                1,
+                msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1",
+            )
+
+        # Now that we have unique ids, add a new node into the graph and re-generate
+        # to make sure that the new node gets a unique id.
+        last_node = next(iter(reversed(m.graph.nodes)))
+        with m.graph.inserting_before(last_node):
+            arg = last_node.args[0]
+            self.assertIsInstance(arg, (list, tuple))
+            arg = arg[0]
+            # Add a function that only requires a single tensor input.
+            n = m.graph.call_function(torch.ops.aten.relu.default, args=(arg,))
+            arg.replace_all_uses_with(n, lambda x: x != n)
+        m.recompile()
+
+        # Regenerate handles, make sure only the new relu node has a new id, and
+        # it doesn't clash with any of the existing ids.
+        generate_numeric_debug_handle(m)
+        handles_after_modification = _extract_debug_handles(m)
+        handles_counter = Counter(handles_after_modification.values())
+        for name, handle in ref_handles.items():
+            self.assertIn(name, handles_after_modification)
+            # Check that handle was unchanged.
+            self.assertEqual(handles_after_modification[name], handle)
+            # Check that total count was unchanged.
+            ref_count = ref_counter[handle]
+            after_count = handles_counter[handle]
+            self.assertEqual(
+                after_count,
+                ref_count,
+                msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}",
+            )
+
+        # Check for relu specifically. Avoid hardcoding the handle id since it
+        # may change with future node ordering changes.
+        self.assertNotEqual(handles_after_modification["relu_default"], 0)
+        self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)


@@ -1,7 +1,7 @@
 import copy
 import logging
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple
 
 import torch
 from torch.ao.ns.fx.utils import compute_sqnr
@@ -19,7 +19,16 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
     """Attach numeric_debug_handle_id for all nodes in the model except for placeholder node
     The graph nodes of input model is modified inplace.
     """
-    unique_id = 0
+    unique_id = -1
+    # Find the max ID that exists in the graph first, in case part of the graph
+    # has already been annotated. This way we guarantee there are no duplicate
+    # handle IDs.
+    for node in graph_module.graph.nodes:
+        unique_id = max(
+            unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, -1)
+        )
+    unique_id += 1
+
     for node in graph_module.graph.nodes:
         if node.op in ["output", "placeholder"]:
             continue
@@ -134,6 +143,17 @@ class QuantizationComparisonResult:
             self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
         )
 
+    def loss(
+        self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+    ) -> torch.Tensor:
+        if self.actual.shape != self.ref.shape:
+            raise ValueError(
+                f"Cannot compare tensors with different shapes: {self.actual.shape} vs {self.ref.shape}"
+            )
+        return loss_function(
+            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
+        )
+
     def __repr__(self) -> str:
         # Don't include the tensors themselves as they are quite large to print
         # out.
@@ -149,6 +169,10 @@ class QuantizationComparisonResult:
         if not isinstance(self.ref, torch.Tensor):
             raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
+        if self.actual.shape != self.ref.shape:
+            raise ValueError(
+                f"Cannot compare tensors with different shapes: ref={self.ref.shape} vs actual={self.actual.shape}"
+            )
 
 
 @dataclass(frozen=True)
@@ -197,8 +221,8 @@ def extract_results_from_loggers(
 
 
 def compare_results(
-    ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
-    actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
+    ref_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
+    actual_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
 ) -> Dict[int, NodeAccuracySummary]:
     """Given two dict mapping from `debug_handle_id` (int) to list of tensors
     return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
@@ -220,16 +244,25 @@ def compare_results(
             )
             continue
         actual_name, actual_stack, actual_stats = actual_results[debug_handle]
-        comparisons[debug_handle] = NodeAccuracySummary(
-            handle=debug_handle,
-            actual_node_name=actual_name,
-            actual_module_stack=_module_stack_to_str(actual_stack),
-            ref_node_name=ref_name,
-            ref_module_stack=_module_stack_to_str(ref_stack),
-            results=[
+        try:
+            results = [
                 QuantizationComparisonResult(actual=a, ref=b)
                 for a, b in zip(actual_stats, ref_stats)
-            ],
+            ]
+        except Exception as e:
+            # Add extra information for an exception from QuantizationComparisonResult
+            # if the shapes didn't match, to include the handle and the node names.
+            raise ValueError(
+                f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}"
+            ) from e
+
+        comparisons[debug_handle] = NodeAccuracySummary(
+            handle=debug_handle,
+            actual_node_name=actual_name or "",
+            actual_module_stack=_module_stack_to_str(actual_stack),
+            ref_node_name=ref_name or "",
+            ref_module_stack=_module_stack_to_str(ref_stack),
+            results=results,
         )
 
     return comparisons