[Dynamo] Cleanup state management for ctx managers (#149689)
Removes state indirection for ctx managers. This isn't needed anymore since VTs are mutable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149689
Approved by: https://github.com/StrongerXi
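For readers unfamiliar with the pattern being deleted, here is a minimal sketch contrasting the two approaches. It is illustrative only, not the Dynamo source: ContextManagerState mirrors the removed dataclass, while the toy ContextWrappingVariable stands in for the real VariableTracker subclass.

# Minimal sketch of the indirection this commit removes (illustrative only,
# not the Dynamo source).
import dataclasses
from typing import Callable, Optional


@dataclasses.dataclass
class ContextManagerState:
    # Old pattern: VariableTrackers were copied rather than mutated, so
    # mutable state lived in this shared container that every copy pointed at.
    cleanup_fn: Optional[Callable] = None

    def cleanup(self):
        if self.cleanup_fn is not None:
            self.cleanup_fn()
            self.cleanup_fn = None


class ContextWrappingVariable:
    # New pattern: VTs are mutable now, so the hook lives directly on the
    # variable and the extra hop through `self.state` disappears.
    def __init__(self):
        self.cleanup_fn: Optional[Callable] = None

    def cleanup(self):
        if self.cleanup_fn is not None:
            self.cleanup_fn()
            self.cleanup_fn = None


# Call sites change from `self.state.cleanup_fn = fn` / `self.state.cleanup()`
# to `self.cleanup_fn = fn` / `self.cleanup()`, as the diff below shows.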
parent cfc08caea9
commit 34743678b9
@@ -20,11 +20,10 @@ consistency between eager execution and compiled graph behavior by capturing and
 restoring state changes.
 """

-import dataclasses
 import inspect
 import sys
 import warnings
-from typing import Callable, Optional, TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Union

 import torch._C
 from torch._guards import Guard
@@ -54,27 +53,6 @@ if TYPE_CHECKING:
     from torch._dynamo.symbolic_convert import InstructionTranslator


-@dataclasses.dataclass
-class ContextManagerState:
-    """
-    Mutating `self` in VariableTracker is not allowed because we copy
-    them. This is a mutable container pointed to by context managers
-    that won't get copied, so it is safe to mutate.
-    """
-
-    cleanup_fn: Optional[Callable] = None
-    proxy: Optional[torch.fx.Proxy] = None
-
-    def cleanup(self):
-        if self.cleanup_fn is not None:
-            self.cleanup_fn()
-            self.cleanup_fn = None
-
-    def cleanup_assert(self):
-        assert self.cleanup_fn, "multiple exits?"
-        self.cleanup()
-
-
 class ContextWrappingVariable(VariableTracker):
     _nonvar_fields = {
         "cm_obj",
@@ -84,13 +62,10 @@ class ContextWrappingVariable(VariableTracker):
         *VariableTracker._nonvar_fields,
     }

-    def __init__(
-        self, target_values, initial_values=None, *, state=None, **kwargs
-    ) -> None:
+    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
         super().__init__(**kwargs)
         self.target_values = target_values
         self.initial_values = initial_values
-        self.state = ContextManagerState() if state is None else state

     def enter(self, tx):
         self._call_func(tx, self.target_values)
@@ -103,11 +78,11 @@
             def fn():
                 self._call_func(tx, self.initial_values)

-        self.state.cleanup_fn = fn
-        tx.output.add_cleanup_hook(self.state.cleanup)
+        self.cleanup_fn = fn
+        tx.output.add_cleanup_hook(self.cleanup)

     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup_assert()
+        self.cleanup_assert()
         return variables.ConstantVariable.create(None)

     def reconstruct_type(self, codegen):
@@ -152,6 +127,15 @@ class ContextWrappingVariable(VariableTracker):
     def exit_on_graph_break(self):
         return True

+    def cleanup(self):
+        if self.cleanup_fn is not None:
+            self.cleanup_fn()
+            self.cleanup_fn = None
+
+    def cleanup_assert(self):
+        assert self.cleanup_fn, "multiple exits?"
+        self.cleanup()
+

 class GenericContextWrappingVariable(UserDefinedObjectVariable):
     # Some methods in ContextWrappingVariable assumes the arguments are
@@ -217,7 +201,7 @@ class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
                 self.prev_state
             ),
         )
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._functorch.set_inplace_requires_grad_allowed,
             (enabled,),
@@ -226,7 +210,7 @@ class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
         return variables.ConstantVariable.create(None)

     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function",
             torch._C._functorch.set_inplace_requires_grad_allowed,
@@ -253,7 +237,7 @@ class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable):
             tx,
             lambda: torch._C._functorch.push_dynamic_layer_stack(self.saved),
         )
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._functorch.pop_dynamic_layer_stack,
             (),
@@ -262,11 +246,11 @@ class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable):
         return variables.ConstantVariable.create(None)

     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function",
             torch._C._functorch.push_dynamic_layer_stack,
-            (self.state.proxy,),
+            (self.proxy,),
             {},
         )
         return variables.ConstantVariable.create(None)
@@ -297,7 +281,7 @@ class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable):
         self.set_cleanup_hook(
             tx, lambda: torch._functorch.eager_transforms.exit_jvp_nesting()
         )
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._functorch._jvp_increment_nesting,
             (),
@@ -306,7 +290,7 @@ class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable):
         return variables.ConstantVariable.create(jvp_level)

     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function", torch._C._functorch._jvp_decrement_nesting, (), {}
         )
@@ -332,7 +316,7 @@ class SetFwdGradEnabledContextManager(ContextWrappingVariable):
             tx,
             lambda: torch._C._set_fwd_grad_enabled(self.prev_state),
         )
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._set_fwd_grad_enabled,
             (mode,),
@@ -341,7 +325,7 @@ class SetFwdGradEnabledContextManager(ContextWrappingVariable):
         return variables.ConstantVariable.create(None)

     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function",
             torch._C._set_fwd_grad_enabled,
@@ -370,7 +354,7 @@ class DualLevelContextManager(ContextWrappingVariable):
         self.set_cleanup_hook(
             tx, lambda: torch.autograd.forward_ad.exit_dual_level(level=self.new_level)
         )
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._enter_dual_level,
             (),
@@ -379,7 +363,7 @@ class DualLevelContextManager(ContextWrappingVariable):
         return variables.ConstantVariable.create(self.new_level)

     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function",
             torch._C._exit_dual_level,
@@ -412,7 +396,7 @@ class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
         install_guard(self._guards_singleton)
         grad_level = torch._C._functorch._grad_increment_nesting()
         self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting())
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._functorch._grad_increment_nesting,
             (),
@@ -421,7 +405,7 @@ class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
         return variables.ConstantVariable.create(grad_level)

     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function", torch._C._functorch._grad_decrement_nesting, (), {}
         )
@@ -492,7 +476,7 @@ class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
             batch_size_value, randomness
         )
        self.set_cleanup_hook(tx, lambda: torch._C._functorch._vmap_decrement_nesting())
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch._C._functorch._vmap_increment_nesting,
             (batch_size_node, randomness),
@@ -501,7 +485,7 @@ class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
         return variables.ConstantVariable.create(vmap_level)

     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup()
+        self.cleanup()
         tx.output.create_node(
             "call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
         )
@@ -589,11 +573,11 @@ class InferenceModeVariable(ContextWrappingVariable):
         self.target_values = target_values

     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup_assert()
+        self.cleanup_assert()
         tx.output.create_node(
             "call_function",
             torch.autograd.grad_mode._exit_inference_mode,
-            (self.state.proxy,),
+            (self.proxy,),
             {},
         )

@@ -619,7 +603,7 @@ class InferenceModeVariable(ContextWrappingVariable):
            torch.autograd.grad_mode._exit_inference_mode(ctx)

         self.set_cleanup_hook(tx, cleanup_hook)
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch.autograd.grad_mode._enter_inference_mode,
             (*self.target_values,),
@@ -657,11 +641,11 @@ class CUDADeviceVariable(ContextWrappingVariable):
         self.target_values = target_values

     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup_assert()
+        self.cleanup_assert()
         tx.output.create_node(
             "call_function",
             torch.cuda._maybe_exchange_device,
-            (self.state.proxy,),
+            (self.proxy,),
             {},
         )
         return variables.ConstantVariable.create(False)
@@ -669,7 +653,7 @@ class CUDADeviceVariable(ContextWrappingVariable):
     def enter(self, tx):
         prev_idx = torch.cuda._exchange_device(*self.target_values)
         self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx))
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function",
             torch.cuda._exchange_device,
             (*self.target_values,),
@@ -730,8 +714,8 @@ class TorchFunctionDisableVariable(ContextWrappingVariable):
                self.initial_torch_function_subclass_enabled
            )

-        self.state.cleanup_fn = fn
-        tx.output.add_cleanup_hook(self.state.cleanup)
+        self.cleanup_fn = fn
+        tx.output.add_cleanup_hook(self.cleanup)

     def _call_func(self, tx: "InstructionTranslator", values):
         assert len(values) == 0
@@ -885,15 +869,15 @@ class AutocastModeVariable(ContextWrappingVariable):
         self.target_values = target_values

     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup_assert()
+        self.cleanup_assert()
         tx.output.create_node(
-            "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
+            "call_function", torch.amp._exit_autocast, (self.proxy,), {}
         )

     def enter(self, tx):
         ctx = torch.amp._enter_autocast(*self.target_values)
         self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
-        self.state.proxy = tx.output.create_node(
+        self.proxy = tx.output.create_node(
             "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
         )

@@ -1021,7 +1005,7 @@ class StreamContextVariable(ContextWrappingVariable):
             (self.initial_values[0].as_proxy(),),
             {},
         )
-        self.state.cleanup_assert()
+        self.cleanup_assert()


 class PreserveVersionContextVariable(ContextWrappingVariable):
@@ -1212,7 +1196,7 @@ class SDPAKernelVariable(ContextWrappingVariable):
         return variables.ConstantVariable.create(None)

     def exit(self, tx: "InstructionTranslator", *args):
-        self.state.cleanup_assert()
+        self.cleanup_assert()
         arg = self._backends_to_nodes(tx, self.prev_backends)
         tx.output.create_node(
             "call_function",