mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Use dataclass features in two classes (#164221)
This PR completes two TODO items by using features of `dataclass`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164221 Approved by: https://github.com/Skylion007, https://github.com/mlazos Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
This commit is contained in:
parent
591997490a
commit
fa90090735
|
|
@ -30,7 +30,7 @@ class CallbackTests(TestCase):
|
||||||
|
|
||||||
def test_callbacks_with_duplicate_prevention(self) -> None:
|
def test_callbacks_with_duplicate_prevention(self) -> None:
|
||||||
trigger = CallbackTrigger.DYNAMO
|
trigger = CallbackTrigger.DYNAMO
|
||||||
compile_id = CompileId(0, 0)
|
compile_id = CompileId(frame_id=0, frame_compile_id=0)
|
||||||
with (
|
with (
|
||||||
callback_handler.install_callbacks(trigger, compile_id),
|
callback_handler.install_callbacks(trigger, compile_id),
|
||||||
callback_handler.install_callbacks(trigger, compile_id),
|
callback_handler.install_callbacks(trigger, compile_id),
|
||||||
|
|
@ -40,7 +40,7 @@ class CallbackTests(TestCase):
|
||||||
|
|
||||||
def test_counter(self) -> None:
|
def test_counter(self) -> None:
|
||||||
trigger = CallbackTrigger.DYNAMO
|
trigger = CallbackTrigger.DYNAMO
|
||||||
compile_id = CompileId(0, 0)
|
compile_id = CompileId(frame_id=0, frame_compile_id=0)
|
||||||
with callback_handler.install_callbacks(trigger, compile_id):
|
with callback_handler.install_callbacks(trigger, compile_id):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
callback_handler._CompilationCallbackHandler__pending_callbacks_counter,
|
callback_handler._CompilationCallbackHandler__pending_callbacks_counter,
|
||||||
|
|
@ -56,7 +56,7 @@ class CallbackTests(TestCase):
|
||||||
AssertionError, "Pending callbacks counter cannot become negative."
|
AssertionError, "Pending callbacks counter cannot become negative."
|
||||||
):
|
):
|
||||||
trigger = CallbackTrigger.DYNAMO
|
trigger = CallbackTrigger.DYNAMO
|
||||||
compile_id = CompileId(0, 0)
|
compile_id = CompileId(frame_id=0, frame_compile_id=0)
|
||||||
with callback_handler.install_callbacks(trigger, str(compile_id)):
|
with callback_handler.install_callbacks(trigger, str(compile_id)):
|
||||||
pass
|
pass
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|
|
||||||
|
|
@ -95,7 +95,11 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
|
||||||
transformed_code = code_map1[frame.f_code]
|
transformed_code = code_map1[frame.f_code]
|
||||||
return wrap_guarded_code(
|
return wrap_guarded_code(
|
||||||
GuardedCode(
|
GuardedCode(
|
||||||
transformed_code, empty_guard_manager, CompileId(None, 0, 0)
|
transformed_code,
|
||||||
|
empty_guard_manager,
|
||||||
|
CompileId(
|
||||||
|
frame_id=None, frame_compile_id=0, compiled_autograd_id=0
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return ConvertFrameReturn()
|
return ConvertFrameReturn()
|
||||||
|
|
@ -105,7 +109,11 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
|
||||||
transformed_code = code_map2[frame.f_code]
|
transformed_code = code_map2[frame.f_code]
|
||||||
return wrap_guarded_code(
|
return wrap_guarded_code(
|
||||||
GuardedCode(
|
GuardedCode(
|
||||||
transformed_code, empty_guard_manager, CompileId(None, 0, 0)
|
transformed_code,
|
||||||
|
empty_guard_manager,
|
||||||
|
CompileId(
|
||||||
|
frame_id=None, frame_compile_id=0, compiled_autograd_id=0
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return ConvertFrameReturn()
|
return ConvertFrameReturn()
|
||||||
|
|
|
||||||
|
|
@ -329,7 +329,9 @@ class TestGuardSerializationBase(torch._inductor.test_case.TestCase):
|
||||||
package=None,
|
package=None,
|
||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
compile_context(CompileContext(CompileId(0, 0))),
|
compile_context(
|
||||||
|
CompileContext(CompileId(frame_id=0, frame_compile_id=0))
|
||||||
|
),
|
||||||
tracing(tracer.output.tracing_context),
|
tracing(tracer.output.tracing_context),
|
||||||
tracer.set_current_tx(),
|
tracer.set_current_tx(),
|
||||||
get_metrics_context(),
|
get_metrics_context(),
|
||||||
|
|
|
||||||
|
|
@ -6864,7 +6864,9 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||||
with patch.object(
|
with patch.object(
|
||||||
CompileContext,
|
CompileContext,
|
||||||
"__init__",
|
"__init__",
|
||||||
lambda self, _: CompileContext_init(self, CompileId(999, 999)),
|
lambda self, _: CompileContext_init(
|
||||||
|
self, CompileId(frame_id=999, frame_compile_id=999)
|
||||||
|
),
|
||||||
):
|
):
|
||||||
_, (coda_a2,) = _run_and_get_stripped_kernels(a, x)
|
_, (coda_a2,) = _run_and_get_stripped_kernels(a, x)
|
||||||
_, (coda_c2,) = _run_and_get_stripped_kernels(c, x)
|
_, (coda_c2,) = _run_and_get_stripped_kernels(c, x)
|
||||||
|
|
|
||||||
|
|
@ -72,8 +72,7 @@ CA_COMPILE_ID_PATTERN = re.compile(
|
||||||
# 3. Compact: The string form is directly displayed by some tools. Special symbols are okay.
|
# 3. Compact: The string form is directly displayed by some tools. Special symbols are okay.
|
||||||
|
|
||||||
|
|
||||||
# TODO: mark as kw_only=True once we drop support for <Python 3.10
|
@dataclass(frozen=True, kw_only=True, slots=True)
|
||||||
@dataclass(frozen=True)
|
|
||||||
class CompileId:
|
class CompileId:
|
||||||
frame_id: Optional[int]
|
frame_id: Optional[int]
|
||||||
# This id is per-frame, and counts how many times we've compiled this
|
# This id is per-frame, and counts how many times we've compiled this
|
||||||
|
|
|
||||||
|
|
@ -1041,10 +1041,10 @@ def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float
|
||||||
return ms
|
return ms
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(slots=True)
|
||||||
class WhyNoFuse:
|
class WhyNoFuse:
|
||||||
# TODO when we drop support for Python < 3.10, we can use
|
name1: str
|
||||||
# @dataclass(slots=True) instead of manually specifying __slots__.
|
name2: str
|
||||||
__slots__ = ["name1", "name2", "reason", "args"]
|
|
||||||
reason: str
|
reason: str
|
||||||
args: tuple[Any, ...]
|
args: tuple[Any, ...]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user