mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[torch][ao] Add customizable loss function to NodeAccuracySummary (#136282)
Summary: Add a customizable loss function callback to NodeAccuracySummary to allow users to pass in their own loss function. Also, fix some type errors and propagate better exception messages when unexpected tensor comparisons occur. Finally, enhance the robustness of `generate_numeric_debug_handle` in the case where it is called multiple times on the same model, by avoiding reuse of the same IDs. Test Plan: Added a test for this case in `test_numeric_debugger`. Reviewed By: jerryzh168 Differential Revision: D62898297 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136282 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
e09c5b6046
commit
3be150653c
|
|
@ -25,8 +25,8 @@ from torch.testing._internal.common_quantization import TestHelperModules
|
|||
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase
|
||||
|
||||
|
||||
def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]:
|
||||
debug_handle_map: Dict[torch.fx.Node, int] = {}
|
||||
def _extract_debug_handles(model) -> Dict[str, int]:
|
||||
debug_handle_map: Dict[str, int] = {}
|
||||
|
||||
for node in model.graph.nodes:
|
||||
if (
|
||||
|
|
@ -187,3 +187,53 @@ class TestNumericDebugger(TestCase):
|
|||
for node_summary in comparison_results.values():
|
||||
if len(node_summary.results) > 0:
|
||||
self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
|
||||
|
||||
def test_added_node_gets_unique_id(self) -> None:
|
||||
m = TestHelperModules.Conv2dThenConv1d()
|
||||
example_inputs = m.example_inputs()
|
||||
m = capture_pre_autograd_graph(m, example_inputs)
|
||||
assert isinstance(m, torch.fx.GraphModule)
|
||||
generate_numeric_debug_handle(m)
|
||||
ref_handles = _extract_debug_handles(m)
|
||||
ref_counter = Counter(ref_handles.values())
|
||||
for k, v in ref_counter.items():
|
||||
self.assertEqual(
|
||||
v,
|
||||
1,
|
||||
msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1",
|
||||
)
|
||||
|
||||
# Now that we have unique ids, add a new node into the graph and re-generate
|
||||
# to make sure that the new node gets a unique id.
|
||||
last_node = next(iter(reversed(m.graph.nodes)))
|
||||
with m.graph.inserting_before(last_node):
|
||||
arg = last_node.args[0]
|
||||
self.assertIsInstance(arg, (list, tuple))
|
||||
arg = arg[0]
|
||||
# Add a function that only requires a single tensor input.
|
||||
n = m.graph.call_function(torch.ops.aten.relu.default, args=(arg,))
|
||||
arg.replace_all_uses_with(n, lambda x: x != n)
|
||||
m.recompile()
|
||||
|
||||
# Regenerate handles, make sure only the new relu node has a new id, and
|
||||
# it doesn't clash with any of the existing ids.
|
||||
generate_numeric_debug_handle(m)
|
||||
handles_after_modification = _extract_debug_handles(m)
|
||||
handles_counter = Counter(handles_after_modification.values())
|
||||
for name, handle in ref_handles.items():
|
||||
self.assertIn(name, handles_after_modification)
|
||||
# Check that handle was unchanged.
|
||||
self.assertEqual(handles_after_modification[name], handle)
|
||||
# Check that total count was unchanged.
|
||||
ref_count = ref_counter[handle]
|
||||
after_count = handles_counter[handle]
|
||||
self.assertEqual(
|
||||
after_count,
|
||||
ref_count,
|
||||
msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}",
|
||||
)
|
||||
|
||||
# Check for relu specifically. Avoid hardcoding the handle id since it
|
||||
# may change with future node ordering changes.
|
||||
self.assertNotEqual(handles_after_modification["relu_default"], 0)
|
||||
self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import copy
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch.ao.ns.fx.utils import compute_sqnr
|
||||
|
|
@ -19,7 +19,16 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
|
|||
"""Attach numeric_debug_handle_id for all nodes in the model except for placeholder node
|
||||
The graph nodes of input model is modified inplace.
|
||||
"""
|
||||
unique_id = 0
|
||||
unique_id = -1
|
||||
# Find the max ID that exists in the graph first, in case part of the graph
|
||||
# has already been annotated. This way we guarantee there are no duplicate
|
||||
# handle IDs.
|
||||
for node in graph_module.graph.nodes:
|
||||
unique_id = max(
|
||||
unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, -1)
|
||||
)
|
||||
unique_id += 1
|
||||
|
||||
for node in graph_module.graph.nodes:
|
||||
if node.op in ["output", "placeholder"]:
|
||||
continue
|
||||
|
|
@ -134,6 +143,17 @@ class QuantizationComparisonResult:
|
|||
self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
|
||||
)
|
||||
|
||||
def loss(
|
||||
self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
if self.actual.shape != self.ref.shape:
|
||||
raise ValueError(
|
||||
f"Cannot compare tensors with different shapes: {self.actual.shape} vs {self.ref.shape}"
|
||||
)
|
||||
return loss_function(
|
||||
self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Don't include the tensors themselves as they are quite large to print
|
||||
# out.
|
||||
|
|
@ -149,6 +169,10 @@ class QuantizationComparisonResult:
|
|||
|
||||
if not isinstance(self.ref, torch.Tensor):
|
||||
raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
|
||||
if self.actual.shape != self.ref.shape:
|
||||
raise ValueError(
|
||||
f"Cannot compare tensors with different shapes: ref={self.ref.shape} vs actual={self.actual.shape}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -197,8 +221,8 @@ def extract_results_from_loggers(
|
|||
|
||||
|
||||
def compare_results(
|
||||
ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
|
||||
actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
|
||||
ref_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
|
||||
actual_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
|
||||
) -> Dict[int, NodeAccuracySummary]:
|
||||
"""Given two dict mapping from `debug_handle_id` (int) to list of tensors
|
||||
return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
|
||||
|
|
@ -220,16 +244,25 @@ def compare_results(
|
|||
)
|
||||
continue
|
||||
actual_name, actual_stack, actual_stats = actual_results[debug_handle]
|
||||
comparisons[debug_handle] = NodeAccuracySummary(
|
||||
handle=debug_handle,
|
||||
actual_node_name=actual_name,
|
||||
actual_module_stack=_module_stack_to_str(actual_stack),
|
||||
ref_node_name=ref_name,
|
||||
ref_module_stack=_module_stack_to_str(ref_stack),
|
||||
try:
|
||||
results = [
|
||||
QuantizationComparisonResult(actual=a, ref=b)
|
||||
for a, b in zip(actual_stats, ref_stats)
|
||||
],
|
||||
]
|
||||
except Exception as e:
|
||||
# Add extra information for an exception from QuantizationComparisonResult
|
||||
# if the shapes didn't match, to include the handle and the node names.
|
||||
raise ValueError(
|
||||
f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}"
|
||||
) from e
|
||||
|
||||
comparisons[debug_handle] = NodeAccuracySummary(
|
||||
handle=debug_handle,
|
||||
actual_node_name=actual_name or "",
|
||||
actual_module_stack=_module_stack_to_str(actual_stack),
|
||||
ref_node_name=ref_name or "",
|
||||
ref_module_stack=_module_stack_to_str(ref_stack),
|
||||
results=results,
|
||||
)
|
||||
|
||||
return comparisons
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user