[Dynamo] Cleanup state management for ctx managers (#149689)

Removes the state indirection for ctx managers. It is no longer needed now that VariableTrackers (VTs) are mutable.
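
For context, a minimal sketch of the pattern change (hypothetical stand-in classes, not the actual Dynamo types; the real change is in ContextWrappingVariable and its subclasses in the diff below):

    # Before: VTs were copied, so mutating `self` was unsafe; mutable state
    # lived behind a shared container (ContextManagerState).
    from dataclasses import dataclass
    from typing import Callable, Optional

    @dataclass
    class _State:  # stand-in for the removed ContextManagerState
        cleanup_fn: Optional[Callable] = None

    class _OldCtxVariable:
        def __init__(self) -> None:
            self.state = _State()  # every copy of the VT points at this container

        def set_cleanup_hook(self, fn: Callable) -> None:
            self.state.cleanup_fn = fn

        def exit(self) -> None:
            if self.state.cleanup_fn is not None:
                self.state.cleanup_fn()
                self.state.cleanup_fn = None

    # After: VTs are mutable, so the fields live directly on the variable itself.
    class _NewCtxVariable:
        def __init__(self) -> None:
            self.cleanup_fn: Optional[Callable] = None

        def set_cleanup_hook(self, fn: Callable) -> None:
            self.cleanup_fn = fn

        def exit(self) -> None:
            if self.cleanup_fn is not None:
                self.cleanup_fn()
                self.cleanup_fn = None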

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149689
Approved by: https://github.com/StrongerXi
Michael Lazos, 2025-03-20 17:40:38 -07:00, committed by PyTorch MergeBot
parent cfc08caea9
commit 34743678b9

@@ -20,11 +20,10 @@ consistency between eager execution and compiled graph behavior by capturing and
restoring state changes.
"""
import dataclasses
import inspect
import sys
import warnings
from typing import Callable, Optional, TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Union
import torch._C
from torch._guards import Guard
@@ -54,27 +53,6 @@ if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
@dataclasses.dataclass
class ContextManagerState:
"""
Mutating `self` in VariableTracker is not allowed because we copy
them. This is a mutable container pointed to by context managers
that won't get copied, so it is safe to mutate.
"""
cleanup_fn: Optional[Callable] = None
proxy: Optional[torch.fx.Proxy] = None
def cleanup(self):
if self.cleanup_fn is not None:
self.cleanup_fn()
self.cleanup_fn = None
def cleanup_assert(self):
assert self.cleanup_fn, "multiple exits?"
self.cleanup()
class ContextWrappingVariable(VariableTracker):
_nonvar_fields = {
"cm_obj",
@@ -84,13 +62,10 @@ class ContextWrappingVariable(VariableTracker):
*VariableTracker._nonvar_fields,
}
def __init__(
self, target_values, initial_values=None, *, state=None, **kwargs
) -> None:
def __init__(self, target_values, initial_values=None, **kwargs) -> None:
super().__init__(**kwargs)
self.target_values = target_values
self.initial_values = initial_values
self.state = ContextManagerState() if state is None else state
def enter(self, tx):
self._call_func(tx, self.target_values)
@@ -103,11 +78,11 @@
def fn():
self._call_func(tx, self.initial_values)
self.state.cleanup_fn = fn
tx.output.add_cleanup_hook(self.state.cleanup)
self.cleanup_fn = fn
tx.output.add_cleanup_hook(self.cleanup)
def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup_assert()
self.cleanup_assert()
return variables.ConstantVariable.create(None)
def reconstruct_type(self, codegen):
@@ -152,6 +127,15 @@ class ContextWrappingVariable(VariableTracker):
def exit_on_graph_break(self):
return True
def cleanup(self):
if self.cleanup_fn is not None:
self.cleanup_fn()
self.cleanup_fn = None
def cleanup_assert(self):
assert self.cleanup_fn, "multiple exits?"
self.cleanup()
class GenericContextWrappingVariable(UserDefinedObjectVariable):
# Some methods in ContextWrappingVariable assumes the arguments are
@@ -217,7 +201,7 @@ class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
self.prev_state
),
)
self.state.proxy = tx.output.create_node(
self.proxy = tx.output.create_node(
"call_function",
torch._C._functorch.set_inplace_requires_grad_allowed,
(enabled,),
@@ -226,7 +210,7 @@ class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
return variables.ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup()
self.cleanup()
tx.output.create_node(
"call_function",
torch._C._functorch.set_inplace_requires_grad_allowed,
@@ -253,7 +237,7 @@ class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable):
tx,
lambda: torch._C._functorch.push_dynamic_layer_stack(self.saved),
)
self.state.proxy = tx.output.create_node(
self.proxy = tx.output.create_node(
"call_function",
torch._C._functorch.pop_dynamic_layer_stack,
(),
@@ -262,11 +246,11 @@ class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable):
return variables.ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup()
self.cleanup()
tx.output.create_node(
"call_function",
torch._C._functorch.push_dynamic_layer_stack,
(self.state.proxy,),
(self.proxy,),
{},
)
return variables.ConstantVariable.create(None)
@@ -297,7 +281,7 @@ class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable):
self.set_cleanup_hook(
tx, lambda: torch._functorch.eager_transforms.exit_jvp_nesting()
)
self.state.proxy = tx.output.create_node(
self.proxy = tx.output.create_node(
"call_function",
torch._C._functorch._jvp_increment_nesting,
(),
@@ -306,7 +290,7 @@ class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable):
return variables.ConstantVariable.create(jvp_level)
def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup()
self.cleanup()
tx.output.create_node(
"call_function", torch._C._functorch._jvp_decrement_nesting, (), {}
)
@@ -332,7 +316,7 @@ class SetFwdGradEnabledContextManager(ContextWrappingVariable):
tx,
lambda: torch._C._set_fwd_grad_enabled(self.prev_state),
)
self.state.proxy = tx.output.create_node(
self.proxy = tx.output.create_node(
"call_function",
torch._C._set_fwd_grad_enabled,
(mode,),
@@ -341,7 +325,7 @@ class SetFwdGradEnabledContextManager(ContextWrappingVariable):
return variables.ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup()
self.cleanup()
tx.output.create_node(
"call_function",
torch._C._set_fwd_grad_enabled,
@@ -370,7 +354,7 @@ class DualLevelContextManager(ContextWrappingVariable):
self.set_cleanup_hook(
tx, lambda: torch.autograd.forward_ad.exit_dual_level(level=self.new_level)
)
self.state.proxy = tx.output.create_node(
self.proxy = tx.output.create_node(
"call_function",
torch._C._enter_dual_level,
(),
@@ -379,7 +363,7 @@ class DualLevelContextManager(ContextWrappingVariable):
return variables.ConstantVariable.create(self.new_level)
def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup()
self.cleanup()
tx.output.create_node(
"call_function",
torch._C._exit_dual_level,
@@ -412,7 +396,7 @@ class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
install_guard(self._guards_singleton)
grad_level = torch._C._functorch._grad_increment_nesting()
self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting())
self.state.proxy = tx.output.create_node(
self.proxy = tx.output.create_node(
"call_function",
torch._C._functorch._grad_increment_nesting,
(),
@@ -421,7 +405,7 @@ class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
return variables.ConstantVariable.create(grad_level)
def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup()
self.cleanup()
tx.output.create_node(
"call_function", torch._C._functorch._grad_decrement_nesting, (), {}
)
@@ -492,7 +476,7 @@ class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
batch_size_value, randomness
)
self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting())
self.state.proxy = tx.output.create_node(
self.proxy = tx.output.create_node(
"call_function",
torch._C._functorch._vmap_increment_nesting,
(batch_size_node, randomness),
@@ -501,7 +485,7 @@ class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
return variables.ConstantVariable.create(vmap_level)
def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup()
self.cleanup()
tx.output.create_node(
"call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
)
@@ -589,11 +573,11 @@ class InferenceModeVariable(ContextWrappingVariable):
self.target_values = target_values
def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup_assert()
self.cleanup_assert()
tx.output.create_node(
"call_function",
torch.autograd.grad_mode._exit_inference_mode,
(self.state.proxy,),
(self.proxy,),
{},
)
@@ -619,7 +603,7 @@ class InferenceModeVariable(ContextWrappingVariable):
torch.autograd.grad_mode._exit_inference_mode(ctx)
self.set_cleanup_hook(tx, cleanup_hook)
self.state.proxy = tx.output.create_node(
self.proxy = tx.output.create_node(
"call_function",
torch.autograd.grad_mode._enter_inference_mode,
(*self.target_values,),
@@ -657,11 +641,11 @@ class CUDADeviceVariable(ContextWrappingVariable):
self.target_values = target_values
def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup_assert()
self.cleanup_assert()
tx.output.create_node(
"call_function",
torch.cuda._maybe_exchange_device,
(self.state.proxy,),
(self.proxy,),
{},
)
return variables.ConstantVariable.create(False)
@@ -669,7 +653,7 @@ class CUDADeviceVariable(ContextWrappingVariable):
def enter(self, tx):
prev_idx = torch.cuda._exchange_device(*self.target_values)
self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx))
self.state.proxy = tx.output.create_node(
self.proxy = tx.output.create_node(
"call_function",
torch.cuda._exchange_device,
(*self.target_values,),
@@ -730,8 +714,8 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
self.initial_torch_function_subclass_enabled
)
self.state.cleanup_fn = fn
tx.output.add_cleanup_hook(self.state.cleanup)
self.cleanup_fn = fn
tx.output.add_cleanup_hook(self.cleanup)
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 0
@@ -885,15 +869,15 @@ class AutocastModeVariable(ContextWrappingVariable):
self.target_values = target_values
def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup_assert()
self.cleanup_assert()
tx.output.create_node(
"call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
"call_function", torch.amp._exit_autocast, (self.proxy,), {}
)
def enter(self, tx):
ctx = torch.amp._enter_autocast(*self.target_values)
self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
self.state.proxy = tx.output.create_node(
self.proxy = tx.output.create_node(
"call_function", torch.amp._enter_autocast, (*self.target_values,), {}
)
@@ -1021,7 +1005,7 @@ class StreamContextVariable(ContextWrappingVariable):
(self.initial_values[0].as_proxy(),),
{},
)
self.state.cleanup_assert()
self.cleanup_assert()
class PreserveVersionContextVariable(ContextWrappingVariable):
@@ -1212,7 +1196,7 @@ class SDPAKernelVariable(ContextWrappingVariable):
return variables.ConstantVariable.create(None)
def exit(self, tx: "InstructionTranslator", *args):
self.state.cleanup_assert()
self.cleanup_assert()
arg = self._backends_to_nodes(tx, self.prev_backends)
tx.output.create_node(
"call_function",