From fa90090735be36b16e7892997fd59811acce456a Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Wed, 1 Oct 2025 03:20:39 +0000 Subject: [PATCH] Use dataclass features in two classes (#164221) This PR completes two TODO items by using features of `dataclass`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164221 Approved by: https://github.com/Skylion007, https://github.com/mlazos Co-authored-by: Aaron Gokaslan --- test/dynamo/test_callback.py | 6 +++--- test/dynamo/test_frame_init.py | 12 ++++++++++-- test/dynamo/test_guard_serialization.py | 4 +++- test/inductor/test_torchinductor.py | 4 +++- torch/_guards.py | 3 +-- torch/_inductor/scheduler.py | 6 +++--- 6 files changed, 23 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_callback.py b/test/dynamo/test_callback.py index 56ff5ff41e6..88151571746 100644 --- a/test/dynamo/test_callback.py +++ b/test/dynamo/test_callback.py @@ -30,7 +30,7 @@ class CallbackTests(TestCase): def test_callbacks_with_duplicate_prevention(self) -> None: trigger = CallbackTrigger.DYNAMO - compile_id = CompileId(0, 0) + compile_id = CompileId(frame_id=0, frame_compile_id=0) with ( callback_handler.install_callbacks(trigger, compile_id), callback_handler.install_callbacks(trigger, compile_id), @@ -40,7 +40,7 @@ class CallbackTests(TestCase): def test_counter(self) -> None: trigger = CallbackTrigger.DYNAMO - compile_id = CompileId(0, 0) + compile_id = CompileId(frame_id=0, frame_compile_id=0) with callback_handler.install_callbacks(trigger, compile_id): self.assertEqual( callback_handler._CompilationCallbackHandler__pending_callbacks_counter, @@ -56,7 +56,7 @@ class CallbackTests(TestCase): AssertionError, "Pending callbacks counter cannot become negative." ): trigger = CallbackTrigger.DYNAMO - compile_id = CompileId(0, 0) + compile_id = CompileId(frame_id=0, frame_compile_id=0) with callback_handler.install_callbacks(trigger, str(compile_id)): pass self.assertEqual( diff --git a/test/dynamo/test_frame_init.py b/test/dynamo/test_frame_init.py index 59fdb20b71f..20cebe9e700 100644 --- a/test/dynamo/test_frame_init.py +++ b/test/dynamo/test_frame_init.py @@ -95,7 +95,11 @@ class FrameInitTests(torch._dynamo.test_case.TestCase): transformed_code = code_map1[frame.f_code] return wrap_guarded_code( GuardedCode( - transformed_code, empty_guard_manager, CompileId(None, 0, 0) + transformed_code, + empty_guard_manager, + CompileId( + frame_id=None, frame_compile_id=0, compiled_autograd_id=0 + ), ) ) return ConvertFrameReturn() @@ -105,7 +109,11 @@ class FrameInitTests(torch._dynamo.test_case.TestCase): transformed_code = code_map2[frame.f_code] return wrap_guarded_code( GuardedCode( - transformed_code, empty_guard_manager, CompileId(None, 0, 0) + transformed_code, + empty_guard_manager, + CompileId( + frame_id=None, frame_compile_id=0, compiled_autograd_id=0 + ), ) ) return ConvertFrameReturn() diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 7e19de29734..96cc26bafd4 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -329,7 +329,9 @@ class TestGuardSerializationBase(torch._inductor.test_case.TestCase): package=None, ) with ( - compile_context(CompileContext(CompileId(0, 0))), + compile_context( + CompileContext(CompileId(frame_id=0, frame_compile_id=0)) + ), tracing(tracer.output.tracing_context), tracer.set_current_tx(), get_metrics_context(), diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 4dd372ed7d7..b7d2c250830 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -6864,7 +6864,9 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar with patch.object( CompileContext, "__init__", - lambda self, _: CompileContext_init(self, CompileId(999, 999)), + lambda self, _: CompileContext_init( + self, CompileId(frame_id=999, frame_compile_id=999) + ), ): _, (coda_a2,) = _run_and_get_stripped_kernels(a, x) _, (coda_c2,) = _run_and_get_stripped_kernels(c, x) diff --git a/torch/_guards.py b/torch/_guards.py index 4ac926d9e45..76a35d1060e 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -72,8 +72,7 @@ CA_COMPILE_ID_PATTERN = re.compile( # 3. Compact: The string form is directly displayed by some tools. Special symbols are okay. -# TODO: mark as kw_only=True once we drop support for Optional[float return ms +@dataclasses.dataclass(slots=True) class WhyNoFuse: - # TODO when we drop support for Python < 3.10, we can use - # @dataclass(slots=True) instead of manually specifying __slots__. - __slots__ = ["name1", "name2", "reason", "args"] + name1: str + name2: str reason: str args: tuple[Any, ...]