diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py
index 7808eb89257..bed1567d2b6 100644
--- a/test/quantization/pt2e/test_numeric_debugger.py
+++ b/test/quantization/pt2e/test_numeric_debugger.py
@@ -25,8 +25,8 @@ from torch.testing._internal.common_quantization import TestHelperModules
 from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase
 
 
-def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]:
-    debug_handle_map: Dict[torch.fx.Node, int] = {}
+def _extract_debug_handles(model) -> Dict[str, int]:
+    debug_handle_map: Dict[str, int] = {}
 
     for node in model.graph.nodes:
         if (
@@ -187,3 +187,53 @@ class TestNumericDebugger(TestCase):
         for node_summary in comparison_results.values():
             if len(node_summary.results) > 0:
                 self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
+
+    def test_added_node_gets_unique_id(self) -> None:
+        m = TestHelperModules.Conv2dThenConv1d()
+        example_inputs = m.example_inputs()
+        m = capture_pre_autograd_graph(m, example_inputs)
+        assert isinstance(m, torch.fx.GraphModule)
+        generate_numeric_debug_handle(m)
+        ref_handles = _extract_debug_handles(m)
+        ref_counter = Counter(ref_handles.values())
+        for k, v in ref_counter.items():
+            self.assertEqual(
+                v,
+                1,
+                msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1",
+            )
+
+        # Now that we have unique ids, add a new node into the graph and re-generate
+        # to make sure that the new node gets a unique id.
+        last_node = next(iter(reversed(m.graph.nodes)))
+        with m.graph.inserting_before(last_node):
+            arg = last_node.args[0]
+            self.assertIsInstance(arg, (list, tuple))
+            arg = arg[0]
+            # Add a function that only requires a single tensor input.
+            n = m.graph.call_function(torch.ops.aten.relu.default, args=(arg,))
+            arg.replace_all_uses_with(n, lambda x: x != n)
+        m.recompile()
+
+        # Regenerate handles, make sure only the new relu node has a new id, and
+        # it doesn't clash with any of the existing ids.
+        generate_numeric_debug_handle(m)
+        handles_after_modification = _extract_debug_handles(m)
+        handles_counter = Counter(handles_after_modification.values())
+        for name, handle in ref_handles.items():
+            self.assertIn(name, handles_after_modification)
+            # Check that handle was unchanged.
+            self.assertEqual(handles_after_modification[name], handle)
+            # Check that total count was unchanged.
+            ref_count = ref_counter[handle]
+            after_count = handles_counter[handle]
+            self.assertEqual(
+                after_count,
+                ref_count,
+                msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}",
+            )
+
+        # Check for relu specifically. Avoid hardcoding the handle id since it
+        # may change with future node ordering changes.
+        self.assertNotEqual(handles_after_modification["relu_default"], 0)
+        self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)
diff --git a/torch/ao/quantization/pt2e/_numeric_debugger.py b/torch/ao/quantization/pt2e/_numeric_debugger.py
index fedcf470a18..3ae57acc8cb 100644
--- a/torch/ao/quantization/pt2e/_numeric_debugger.py
+++ b/torch/ao/quantization/pt2e/_numeric_debugger.py
@@ -1,7 +1,7 @@
 import copy
 import logging
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple
 
 import torch
 from torch.ao.ns.fx.utils import compute_sqnr
@@ -19,7 +19,16 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
     """Attach numeric_debug_handle_id for all nodes in the model except for placeholder node
     The graph nodes of input model is modified inplace.
     """
-    unique_id = 0
+    unique_id = -1
+    # Find the max ID that exists in the graph first, in case part of the graph
+    # has already been annotated. This way we guarantee there are no duplicate
+    # handle IDs.
+    for node in graph_module.graph.nodes:
+        unique_id = max(
+            unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, -1)
+        )
+    unique_id += 1
+
     for node in graph_module.graph.nodes:
         if node.op in ["output", "placeholder"]:
             continue
@@ -134,6 +143,17 @@ class QuantizationComparisonResult:
             self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
         )
 
+    def loss(
+        self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+    ) -> torch.Tensor:
+        if self.actual.shape != self.ref.shape:
+            raise ValueError(
+                f"Cannot compare tensors with different shapes: {self.actual.shape} vs {self.ref.shape}"
+            )
+        return loss_function(
+            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
+        )
+
     def __repr__(self) -> str:
         # Don't include the tensors themselves as they are quite large to print
         # out.
@@ -149,6 +169,10 @@
             raise ValueError(f"`self.actual` value must be a Tensor, got: {self.actual}")
         if not isinstance(self.ref, torch.Tensor):
             raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
+        if self.actual.shape != self.ref.shape:
+            raise ValueError(
+                f"Cannot compare tensors with different shapes: ref={self.ref.shape} vs actual={self.actual.shape}"
+            )
 
 
 @dataclass(frozen=True)
@@ -197,8 +221,8 @@ def extract_results_from_loggers(
 
 
 def compare_results(
-    ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
-    actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
+    ref_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
+    actual_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
 ) -> Dict[int, NodeAccuracySummary]:
     """Given two dict mapping from `debug_handle_id` (int) to list of tensors
     return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
@@ -220,16 +244,25 @@
             )
             continue
         actual_name, actual_stack, actual_stats = actual_results[debug_handle]
-        comparisons[debug_handle] = NodeAccuracySummary(
-            handle=debug_handle,
-            actual_node_name=actual_name,
-            actual_module_stack=_module_stack_to_str(actual_stack),
-            ref_node_name=ref_name,
-            ref_module_stack=_module_stack_to_str(ref_stack),
-            results=[
+        try:
+            results = [
                 QuantizationComparisonResult(actual=a, ref=b)
                 for a, b in zip(actual_stats, ref_stats)
-            ],
+            ]
+        except Exception as e:
+            # Add extra information for an exception from QuantizationComparisonResult
+            # if the shapes didn't match, to include the handle and the node names.
+            raise ValueError(
+                f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}"
+            ) from e
+
+        comparisons[debug_handle] = NodeAccuracySummary(
+            handle=debug_handle,
+            actual_node_name=actual_name or "",
+            actual_module_stack=_module_stack_to_str(actual_stack),
+            ref_node_name=ref_name or "",
+            ref_module_stack=_module_stack_to_str(ref_stack),
+            results=results,
         )
 
     return comparisons
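
A minimal sketch of the re-annotation behavior this patch introduces, for trying it outside the test suite. It assumes `generate_numeric_debug_handle`, `CUSTOM_KEY`, and `NUMERIC_DEBUG_HANDLE_KEY` are all importable from the private `_numeric_debugger` module (as used in the hunks above), that `capture_pre_autograd_graph` comes from `torch._export` as in the test file, and that the unchanged tail of the annotation loop skips nodes that already carry a handle, which is what `test_added_node_gets_unique_id` verifies. `M` and `snapshot` are made-up helpers for illustration:

import torch
# Assumption: these names live in the private module, as the patch uses them.
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.pt2e._numeric_debugger import (
    CUSTOM_KEY,
    NUMERIC_DEBUG_HANDLE_KEY,
    generate_numeric_debug_handle,
)


class M(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x + 1)


m = capture_pre_autograd_graph(M(), (torch.randn(2, 2),))
generate_numeric_debug_handle(m)


def snapshot(gm: torch.fx.GraphModule) -> dict:
    # Map node name -> debug handle for every annotated node.
    return {
        n.name: n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
        for n in gm.graph.nodes
        if n.op not in ("placeholder", "output")
    }


before = snapshot(m)
# Re-running now starts numbering above the current max ID, so existing
# handles survive and no duplicate IDs are handed out (the behavior the
# new test asserts).
generate_numeric_debug_handle(m)
assert snapshot(m) == before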
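Likewise, a quick sketch of the new `QuantizationComparisonResult.loss` hook and the shape validation that now runs both in `__post_init__` and in `loss` itself. The tensors below are made-up stand-ins for captured logger outputs; any callable taking `(actual, ref)` float tensors should work, e.g. the losses in `torch.nn.functional`:

import torch
import torch.nn.functional as F

from torch.ao.quantization.pt2e._numeric_debugger import QuantizationComparisonResult

# Hypothetical stand-ins for logger outputs from the reference and
# quantized models.
ref = torch.randn(4, 8)
actual = ref + 0.01 * torch.randn(4, 8)

result = QuantizationComparisonResult(actual=actual, ref=ref)
print(result.sqnr)              # built-in SQNR metric
print(result.loss(F.mse_loss))  # user-supplied loss; inputs are cast to float32
print(result.loss(F.l1_loss))

# Mismatched shapes are now rejected up front instead of failing deep inside
# compute_sqnr or the user's loss function.
try:
    QuantizationComparisonResult(actual=torch.randn(2, 3), ref=torch.randn(4))
except ValueError as e:
    print(e)  # Cannot compare tensors with different shapes: ...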