From 3be150653cedb51f4588b15a46dc0522de672f9f Mon Sep 17 00:00:00 2001
From: Riley Dulin
Date: Tue, 24 Sep 2024 03:28:09 +0000
Subject: [PATCH] [torch][ao] Add customizable loss function to
 NodeAccuracySummary (#136282)

Summary:
Add a customizable loss function callback to NodeAccuracySummary so that
users can compare tensors with their own loss function, in addition to the
built-in SQNR (see the sketch below).

Also, fix some type errors and propagate better exception messages when
unexpected tensor comparisons occur.

Finally, make `generate_numeric_debug_handle` robust to being called
multiple times on the same model by never reusing an already-assigned
handle ID; a sketch of the resulting behavior appears in the note after
the diff.
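
For illustration only, a rough sketch of the intended call pattern. The
inputs `ref_results` and `actual_results` (dicts of the shape accepted by
compare_results, e.g. as produced by extract_results_from_loggers) are
assumed to already exist, and F.mse_loss is just a stand-in for a
user-defined loss:

    import torch.nn.functional as F

    from torch.ao.quantization.pt2e._numeric_debugger import compare_results

    # compare_results maps each debug handle to a NodeAccuracySummary,
    # whose `results` list holds QuantizationComparisonResult pairs.
    comparisons = compare_results(ref_results, actual_results)
    for handle, summary in comparisons.items():
        for result in summary.results:
            # Any Callable[[Tensor, Tensor], Tensor] is accepted; `loss`
            # casts both tensors to float32 before applying the callback.
            print(handle, result.loss(F.mse_loss))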
Test Plan: Added a test for this case in `test_numeric_debugger`.

Reviewed By: jerryzh168

Differential Revision: D62898297

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136282
Approved by: https://github.com/jerryzh168
---
 .../pt2e/test_numeric_debugger.py             | 54 +++++++++++++++++-
 .../ao/quantization/pt2e/_numeric_debugger.py | 57 +++++++++++++++----
 2 files changed, 97 insertions(+), 14 deletions(-)

diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py
index 7808eb89257..bed1567d2b6 100644
--- a/test/quantization/pt2e/test_numeric_debugger.py
+++ b/test/quantization/pt2e/test_numeric_debugger.py
@@ -25,8 +25,8 @@ from torch.testing._internal.common_quantization import TestHelperModules
 from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase
 
 
-def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]:
-    debug_handle_map: Dict[torch.fx.Node, int] = {}
+def _extract_debug_handles(model) -> Dict[str, int]:
+    debug_handle_map: Dict[str, int] = {}
 
     for node in model.graph.nodes:
         if (
@@ -187,3 +187,53 @@ class TestNumericDebugger(TestCase):
         for node_summary in comparison_results.values():
             if len(node_summary.results) > 0:
                 self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
+
+    def test_added_node_gets_unique_id(self) -> None:
+        m = TestHelperModules.Conv2dThenConv1d()
+        example_inputs = m.example_inputs()
+        m = capture_pre_autograd_graph(m, example_inputs)
+        assert isinstance(m, torch.fx.GraphModule)
+        generate_numeric_debug_handle(m)
+        ref_handles = _extract_debug_handles(m)
+        ref_counter = Counter(ref_handles.values())
+        for k, v in ref_counter.items():
+            self.assertEqual(
+                v,
+                1,
+                msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1",
+            )
+
+        # Now that we have unique ids, add a new node into the graph and re-generate
+        # to make sure that the new node gets a unique id.
+        last_node = next(iter(reversed(m.graph.nodes)))
+        with m.graph.inserting_before(last_node):
+            arg = last_node.args[0]
+            self.assertIsInstance(arg, (list, tuple))
+            arg = arg[0]
+            # Add a function that only requires a single tensor input.
+            n = m.graph.call_function(torch.ops.aten.relu.default, args=(arg,))
+            arg.replace_all_uses_with(n, lambda x: x != n)
+        m.recompile()
+
+        # Regenerate handles, make sure only the new relu node has a new id, and
+        # it doesn't clash with any of the existing ids.
+        generate_numeric_debug_handle(m)
+        handles_after_modification = _extract_debug_handles(m)
+        handles_counter = Counter(handles_after_modification.values())
+        for name, handle in ref_handles.items():
+            self.assertIn(name, handles_after_modification)
+            # Check that handle was unchanged.
+            self.assertEqual(handles_after_modification[name], handle)
+            # Check that total count was unchanged.
+            ref_count = ref_counter[handle]
+            after_count = handles_counter[handle]
+            self.assertEqual(
+                after_count,
+                ref_count,
+                msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}",
+            )
+
+        # Check for relu specifically. Avoid hardcoding the handle id since it
+        # may change with future node ordering changes.
+        self.assertNotEqual(handles_after_modification["relu_default"], 0)
+        self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)
diff --git a/torch/ao/quantization/pt2e/_numeric_debugger.py b/torch/ao/quantization/pt2e/_numeric_debugger.py
index fedcf470a18..3ae57acc8cb 100644
--- a/torch/ao/quantization/pt2e/_numeric_debugger.py
+++ b/torch/ao/quantization/pt2e/_numeric_debugger.py
@@ -1,7 +1,7 @@
 import copy
 import logging
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple
 
 import torch
 from torch.ao.ns.fx.utils import compute_sqnr
@@ -19,7 +19,16 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
     """Attach numeric_debug_handle_id for all nodes in the model except for placeholder node
     The graph nodes of input model is modified inplace.
     """
-    unique_id = 0
+    unique_id = -1
+    # Find the max ID that exists in the graph first, in case part of the graph
+    # has already been annotated. This way we guarantee there are no duplicate
+    # handle IDs.
+    for node in graph_module.graph.nodes:
+        unique_id = max(
+            unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, -1)
+        )
+    unique_id += 1
+
     for node in graph_module.graph.nodes:
         if node.op in ["output", "placeholder"]:
             continue
@@ -134,6 +143,17 @@ class QuantizationComparisonResult:
             self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
         )
 
+    def loss(
+        self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+    ) -> torch.Tensor:
+        if self.actual.shape != self.ref.shape:
+            raise ValueError(
+                f"Cannot compare tensors with different shapes: {self.actual.shape} vs {self.ref.shape}"
+            )
+        return loss_function(
+            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
+        )
+
     def __repr__(self) -> str:
         # Don't include the tensors themselves as they are quite large to print
         # out.
@@ -149,6 +169,10 @@ class QuantizationComparisonResult:
         if not isinstance(self.ref, torch.Tensor):
             raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
 
+        if self.actual.shape != self.ref.shape:
+            raise ValueError(
+                f"Cannot compare tensors with different shapes: ref={self.ref.shape} vs actual={self.actual.shape}"
+            )
 
 
 @dataclass(frozen=True)
@@ -197,8 +221,8 @@ def extract_results_from_loggers(
 
 
 def compare_results(
-    ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
-    actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
+    ref_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
+    actual_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
 ) -> Dict[int, NodeAccuracySummary]:
     """Given two dict mapping from `debug_handle_id` (int) to list of tensors
     return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
@@ -220,16 +244,25 @@ def compare_results(
             )
             continue
         actual_name, actual_stack, actual_stats = actual_results[debug_handle]
-        comparisons[debug_handle] = NodeAccuracySummary(
-            handle=debug_handle,
-            actual_node_name=actual_name,
-            actual_module_stack=_module_stack_to_str(actual_stack),
-            ref_node_name=ref_name,
-            ref_module_stack=_module_stack_to_str(ref_stack),
-            results=[
+        try:
+            results = [
                 QuantizationComparisonResult(actual=a, ref=b)
                 for a, b in zip(actual_stats, ref_stats)
-            ],
+            ]
+        except Exception as e:
+            # Add extra information for an exception from QuantizationComparisonResult
+            # if the shapes didn't match, to include the handle and the node names.
+            raise ValueError(
+                f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}"
+            ) from e
+
+        comparisons[debug_handle] = NodeAccuracySummary(
+            handle=debug_handle,
+            actual_node_name=actual_name or "",
+            actual_module_stack=_module_stack_to_str(actual_stack),
+            ref_node_name=ref_name or "",
+            ref_module_stack=_module_stack_to_str(ref_stack),
+            results=results,
        )
 
     return comparisons
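
Usage sketch (illustration only, not part of the diff): with the max-ID
scan above, `generate_numeric_debug_handle` can be re-run after graph
edits without creating duplicate handles. A minimal sketch, assuming `m`
is a captured torch.fx.GraphModule as in the test above:

    from torch.ao.quantization.pt2e._numeric_debugger import (
        generate_numeric_debug_handle,
    )

    generate_numeric_debug_handle(m)  # first pass: assigns fresh IDs
    # ... insert new nodes into m.graph, then m.recompile() ...
    generate_numeric_debug_handle(m)  # second pass: existing handles are
    # preserved; only the new nodes get IDs, starting at max(existing) + 1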