# mypy: ignore-errors

import dataclasses
import inspect
import sys
import warnings
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union

import torch._C
from torch._guards import Guard

from .. import variables
from ..bytecode_transformation import (
    create_call_function,
    create_instruction,
    create_setup_with,
)
from ..device_interface import get_interface_for_device
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
from .base import VariableTracker
from .functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
    WrappedUserFunctionVariable,
    WrappedUserMethodVariable,
)
from .user_defined import UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


@dataclasses.dataclass
class ContextManagerState:
    """
    Mutating `self` in VariableTracker is not allowed because we copy them.
    This is a mutable container pointed to by context managers that won't
    get copied, so it is safe to mutate.
    """

    cleanup_fn: Optional[Callable] = None
    proxy: Optional[torch.fx.Proxy] = None

    def cleanup(self):
        if self.cleanup_fn is not None:
            self.cleanup_fn()
            self.cleanup_fn = None

    def cleanup_assert(self):
        assert self.cleanup_fn, "multiple exits?"
        self.cleanup()


class ContextWrappingVariable(VariableTracker):
    _nonvar_fields = {
        "cm_obj",
        "target_values",
        "initial_values",
        "state",
        *VariableTracker._nonvar_fields,
    }

    def __init__(
        self, target_values, initial_values=None, *, state=None, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.target_values = target_values
        self.initial_values = initial_values
        self.state = ContextManagerState() if state is None else state

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        self.set_cleanup_hook(tx)
        return variables.ConstantVariable.create(None)

    def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None):
        if fn is None:

            def fn():
                self._call_func(tx, self.initial_values)

        self.state.cleanup_fn = fn
        tx.output.add_cleanup_hook(self.state.cleanup)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup_assert()
        return variables.ConstantVariable.create(None)

    def reconstruct_type(self, codegen):
        codegen(
            AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name())
        )

    def reconstruct(self, codegen):
        codegen.add_push_null(lambda: self.reconstruct_type(codegen))
        target_values = self.target_values
        if not target_values:
            target_values = ()
        codegen.extend_output(
            [codegen.create_load_const(val) for val in target_values]
        )
        codegen.extend_output(create_call_function(len(target_values), False))

    def module_name(self):
        raise NotImplementedError("module_name called on base")

    def fn_name(self):
        raise NotImplementedError("fn_name called on base")

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert len(args) == 1
        if isinstance(args[0], NestedUserFunctionVariable):
            args[0] = UserFunctionVariable(args[0].get_function())
        assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable))

        if isinstance(args[0], UserMethodVariable):
            return WrappedUserMethodVariable(args[0], self)

        if isinstance(args[0], UserFunctionVariable):
            return WrappedUserFunctionVariable(args[0], self)


class GenericContextWrappingVariable(UserDefinedObjectVariable):
    # Some methods in ContextWrappingVariable assume the arguments are
    # Python constants, which might not always be the case here.
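    #
    # Illustrative sketch (hypothetical user code, not part of this module) of the
    # kind of object this variable wraps when Dynamo traces a `with` statement over
    # an arbitrary user-defined context manager:
    #
    #     class MyFlag:
    #         def __enter__(self):
    #             self.prev = get_flag()  # hypothetical helpers
    #             set_flag(True)
    #         def __exit__(self, *exc):
    #             set_flag(self.prev)
    #
    #     @torch.compile
    #     def f(x):
    #         with MyFlag():  # traced via GenericContextWrappingVariable
    #             return x + 1
    #
    # __enter__/__exit__ are inlined as UserMethodVariable calls (see enter/exit
    # below) rather than being assumed to take constant arguments.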
    def __init__(self, cm_obj, **kwargs) -> None:
        assert cm_obj is not None
        super().__init__(
            value=cm_obj,
            value_type=cm_obj.__class__,
            **kwargs,
        )
        self.cm_obj = cm_obj

    def module_name(self):
        return self.cm_obj.__module__

    def fn_name(self):
        return type(self.cm_obj).__name__

    def enter(self, tx):
        source = None if self.source is None else AttrSource(self.source, "__enter__")
        try:
            return variables.UserMethodVariable(
                self.cm_obj.__enter__.__func__,
                self,
                source=source,
            ).call_function(tx, [], {})
        except Unsupported as e:
            unimplemented(
                f"Unsupported context manager {self.cm_obj}'s __enter__ function",
                from_exc=e,
            )

    def exit(self, tx: "InstructionTranslator", *args):
        source = None if self.source is None else AttrSource(self.source, "__exit__")
        try:
            x = variables.UserMethodVariable(
                self.cm_obj.__exit__.__func__,
                self,
                source=source,
            ).call_function(
                tx,
                [
                    variables.ConstantVariable.create(None),
                    variables.ConstantVariable.create(None),
                    variables.ConstantVariable.create(None),
                ],
                {},
            )
        except Unsupported as e:
            unimplemented(
                f"Unsupported context manager {self.cm_obj}'s __exit__ function",
                from_exc=e,
            )

        tx.generic_context_manager_depth -= 1
        return x


class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
    """represents torch._C._functorch.set_inplace_requires_grad_allowed()"""

    @staticmethod
    def create(tx: "InstructionTranslator", target_values, **kwargs):
        return GradInplaceRequiresGradCtxManagerVariable(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        [enabled] = self.target_values
        self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed()
        torch._C._functorch.set_inplace_requires_grad_allowed(enabled)
        self.set_cleanup_hook(
            tx,
            lambda: torch._C._functorch.set_inplace_requires_grad_allowed(
                self.prev_state
            ),
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch.set_inplace_requires_grad_allowed,
            (enabled,),
            {},
        )
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function",
            torch._C._functorch.set_inplace_requires_grad_allowed,
            (self.prev_state,),
            {},
        )
        return variables.ConstantVariable.create(None)


class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch.func.jvp increment/decrement nesting"""

    # A guard is needed as the jvp level is baked into the torch FX graph.
    # This is fine if jvp is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a jvp
    # call from eager that calls the compiled function, as the jvp levels
    # may be different.
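    #
    # Illustrative example (not from this module): if `g = torch.compile(f)` captures
    # a graph while no jvp transform is active, a later eager call like
    # `torch.func.jvp(g, (x,), (v,))` runs the same graph at a different jvp level;
    # the FUNCTORCH_STACK_MATCH guard installed below detects the mismatch and
    # triggers recompilation instead of reusing the stale graph.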
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx: "InstructionTranslator", **kwargs):
        var = JvpIncrementNestingCtxManagerVariable(
            target_values=None,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        jvp_level = torch._functorch.eager_transforms.enter_jvp_nesting()
        self.set_cleanup_hook(
            tx, lambda: torch._functorch.eager_transforms.exit_jvp_nesting()
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._jvp_increment_nesting,
            (),
            {},
        )
        return variables.ConstantVariable.create(jvp_level)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._jvp_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)


class SetFwdGradEnabledContextManager(ContextWrappingVariable):
    """
    represents torch.autograd.forward_ad._set_fwd_grad_enabled() to
    enable/disable forward-mode grad
    """

    @staticmethod
    def create(tx: "InstructionTranslator", target_values, **kwargs):
        return SetFwdGradEnabledContextManager(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        [mode] = self.target_values
        self.prev_state = torch._C._is_fwd_grad_enabled()
        torch._C._set_fwd_grad_enabled(mode)
        self.set_cleanup_hook(
            tx,
            lambda: torch._C._set_fwd_grad_enabled(self.prev_state),
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._set_fwd_grad_enabled,
            (mode,),
            {},
        )
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function",
            torch._C._set_fwd_grad_enabled,
            (self.prev_state,),
            {},
        )
        return variables.ConstantVariable.create(None)


class DualLevelContextManager(ContextWrappingVariable):
    """Represents torch.autograd.forward_ad.dual_level ctx manager"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL)

    @staticmethod
    def create(tx: "InstructionTranslator", **kwargs):
        return DualLevelContextManager(
            target_values=None,
            initial_values=None,
            **kwargs,
        )

    def enter(self, tx):
        install_guard(self._guards_singleton)
        self.new_level = torch.autograd.forward_ad.enter_dual_level()
        self.set_cleanup_hook(
            tx, lambda: torch.autograd.forward_ad.exit_dual_level(level=self.new_level)
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._enter_dual_level,
            (),
            {},
        )
        return variables.ConstantVariable.create(self.new_level)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function",
            torch._C._exit_dual_level,
            (self.new_level,),
            {},
        )
        return variables.ConstantVariable.create(None)


class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch.func.grad increment/decrement nesting"""

    # A guard is needed as the grad level is baked into the torch FX graph.
    # This is fine if grad is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a grad
    # call from eager that calls the compiled function, as the grad levels
    # may be different.
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx: "InstructionTranslator", **kwargs):
        var = GradIncrementNestingCtxManagerVariable(
            target_values=None,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        grad_level = torch._C._functorch._grad_increment_nesting()
        self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting())
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._grad_increment_nesting,
            (),
            {},
        )
        return variables.ConstantVariable.create(grad_level)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._grad_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)


class CatchWarningsCtxManagerVariable(ContextWrappingVariable):
    """Delay a call to warnings.catch_warnings"""

    @staticmethod
    def create(tx: "InstructionTranslator", catch_warnings_args):
        return CatchWarningsCtxManagerVariable(
            catch_warnings_args=catch_warnings_args,
            target_values=None,
            initial_values=None,
        )

    def __init__(self, catch_warnings_args, **kwargs) -> None:
        assert isinstance(catch_warnings_args, dict), catch_warnings_args
        super().__init__(**kwargs)
        self.catch_warnings_args = catch_warnings_args

    def enter(self, tx):
        kwargs = {
            k: v.as_python_constant() for k, v in self.catch_warnings_args.items()
        }
        ctx_val = warnings.catch_warnings(**kwargs)
        self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None))
        return variables.ConstantVariable.create(ctx_val.__enter__())

    def reconstruct(self, cg):
        cg.add_push_null(lambda: cg.load_import_from("warnings", "catch_warnings"))
        cg.foreach(self.catch_warnings_args.values())
        keys = tuple(self.catch_warnings_args.keys())
        cg.extend_output(cg.create_call_function_kw(len(keys), keys, False))


class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
    """represents torch VMap increment/decrement nesting"""

    # A guard is needed as the vmap level is baked into the torch FX graph
    # generated. This is fine if vmap is only called from within the function
    # being compiled. But the FX graph may be invalid in the case of a vmap
    # call from eager that calls the compiled function, as the vmap levels
    # may be different.
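    #
    # For example (illustrative): `torch.vmap(torch.compile(f))(xs)` called from eager
    # would run a graph that baked in the wrong vmap level; the guard below re-checks
    # the functorch interpreter stack so such calls recompile rather than silently
    # reuse the stale graph.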
    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)

    @staticmethod
    def create(tx: "InstructionTranslator", target_values, **kwargs):
        var = VmapIncrementNestingCtxManagerVariable(
            target_values=target_values,
            initial_values=None,
            **kwargs,
        )
        return var

    def enter(self, tx):
        install_guard(self._guards_singleton)
        batch_size, randomness = self.target_values
        vmap_level = torch._C._functorch._vmap_increment_nesting(
            batch_size, randomness
        )
        self.set_cleanup_hook(
            tx, lambda: torch._C._functorch._vmap_decrement_nesting()
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch._C._functorch._vmap_increment_nesting,
            (batch_size, randomness),
            {},
        )
        return variables.ConstantVariable.create(vmap_level)

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup()
        tx.output.create_node(
            "call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
        )
        return variables.ConstantVariable.create(None)


class GradModeVariable(ContextWrappingVariable):
    """represents torch.{no_grad,enable_grad,set_grad_enabled}()"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE)

    @staticmethod
    def create(tx: "InstructionTranslator", target_value, initialized=False, **kwargs):
        var = GradModeVariable(
            target_values=[target_value],
            initial_values=[torch.is_grad_enabled()],
            **kwargs,
        )
        if initialized:
            var._call_func(tx, var.target_values)
        return var

    def __init__(
        self, target_values, initial_values=None, initialized=True, **kwargs
    ) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable.create(None)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        self._call_func(tx, self.initial_values)  # undo eager initialization
        return super().call_function(tx, args, kwargs)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        value = values[0]
        # Coalesce grad mode mutations
        if torch.is_grad_enabled() != value:
            tx.output.create_node(
                "call_function", torch._C._set_grad_enabled, (value,), {}
            )
            torch._C._set_grad_enabled(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "set_grad_enabled"


class InferenceModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(tx: "InstructionTranslator", target_value, **kwargs):
        var = InferenceModeVariable(
            [target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs
        )
        return var

    def __init__(
        self,
        target_values,
        initial_values=None,
        **kwargs,
    ) -> None:
        if initial_values is None:
            # This must be called here since function defaults are evaluated at import time
            initial_values = torch.is_inference_mode_enabled()
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._exit_inference_mode,
            (self.state.proxy,),
            {},
        )

    def enter(self, tx):
        ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values)
        self.set_cleanup_hook(
            tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx)
        )
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch.autograd.grad_mode._enter_inference_mode,
            (*self.target_values,),
            {},
        )

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "inference_mode"


class CUDADeviceVariable(ContextWrappingVariable):
    """represents torch.cuda.device"""

    @staticmethod
    def create(tx: "InstructionTranslator", device, **kwargs):
        var = CUDADeviceVariable(
            target_values=[torch.cuda._get_device_index(device, optional=True)],
            initial_values=None,
            **kwargs,
        )
        return var

    def __init__(
        self,
        target_values,
        initial_values=None,
        **kwargs,
    ) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function",
            torch.cuda._maybe_exchange_device,
            (self.state.proxy,),
            {},
        )
        return variables.ConstantVariable.create(False)

    def enter(self, tx):
        prev_idx = torch.cuda._exchange_device(*self.target_values)
        self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx))
        self.state.proxy = tx.output.create_node(
            "call_function",
            torch.cuda._exchange_device,
            (*self.target_values,),
            {},
        )

    def module_name(self):
        return "torch.cuda"

    def fn_name(self):
        return "device"


class TorchFunctionDisableVariable(ContextWrappingVariable):
    """represents whether torch function overrides are enabled or not"""

    _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE)

    @staticmethod
    def create(tx: "InstructionTranslator", **kwargs):
        var = TorchFunctionDisableVariable(
            target_values=[False],
            initial_values=[tx.output.torch_function_enabled],
            **kwargs,
        )
        # mlazos: I think this is here to make sure we don't reinvoke on clone()
        var._call_func(tx, [False])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        tx.output.set_torch_function_state(values[0])


class DeterministicAlgorithmsVariable(ContextWrappingVariable):
    """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()"""

    _guards_singleton = Guard(
        GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS
    )

    @staticmethod
    def create(tx: "InstructionTranslator", target_value, **kwargs):
        var = DeterministicAlgorithmsVariable(
            target_values=[target_value],
            initial_values=[torch.are_deterministic_algorithms_enabled()],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        install_guard(self._guards_singleton)

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        value = values[0]
        tx.output.create_node(
            "call_function", torch._C._set_deterministic_algorithms, (value,), {}
        )
        torch._C._set_deterministic_algorithms(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "use_deterministic_algorithms"


class DisabledSavedTensorsHooksVariable(ContextWrappingVariable):
    """represents torch.autograd.graph.disable_saved_tensors_hooks."""

    @staticmethod
    def create(tx: "InstructionTranslator", target_value, **kwargs):
        var = DisabledSavedTensorsHooksVariable(
            target_values=[target_value],
            initial_values=[
                torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
            ],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        var.set_cleanup_hook(tx)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def enter(self, tx):
        return variables.ConstantVariable.create(None)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        value = values[0]
        if value is not None:
            # Disable `saved_tensors_hooks` with message (`value`)
            # OR
            # we are exiting this context and restoring the previous message.
            tx.output.create_node(
                "call_function",
                torch._C._autograd._saved_tensors_hooks_disable,
                (value,),
                {},
            )
            torch._C._autograd._saved_tensors_hooks_disable(value)
        else:
            # We are exiting this context and if prev_message was None,
            # we re-enable `saved_tensors_hooks`.
            tx.output.create_node(
                "call_function", torch._C._autograd._saved_tensors_hooks_enable, (), {}
            )
            torch._C._autograd._saved_tensors_hooks_enable()

    def module_name(self):
        return "torch.autograd.graph"

    def fn_name(self):
        return "disable_saved_tensors_hooks"


class AutocastModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(func, args, kwargs):
        assert func in [
            torch.amp.autocast_mode.autocast,
            torch.cuda.amp.autocast,
            torch.cpu.amp.autocast,
        ]
        # Signature of torch.autocast:
        #   device_type: str,
        #   dtype: Optional[_dtype] = None,
        #   enabled: bool = True,
        #   cache_enabled: Optional[bool] = None
        bound_args = inspect.signature(func).bind(*args, **kwargs)
        bound_args.apply_defaults()
        target_values = []
        kwargs.clear()

        for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
            if key == "device_type" and func in [
                torch.cuda.amp.autocast,
                torch.cpu.amp.autocast,
            ]:
                arg = "cuda" if func is torch.cuda.amp.autocast else "cpu"
            else:
                arg = bound_args.arguments[key]

            if isinstance(arg, VariableTracker):
                target_values.append(arg.as_python_constant())
            else:
                target_values.append(arg)

        var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values

    def exit(self, tx: "InstructionTranslator", *args):
        self.state.cleanup_assert()
        tx.output.create_node(
            "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
        )

    def enter(self, tx):
        ctx = torch.amp._enter_autocast(*self.target_values)
        self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
        self.state.proxy = tx.output.create_node(
            "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
        )

    def module_name(self):
        return "torch.amp.autocast_mode"

    def fn_name(self):
        return "autocast"


class NullContextVariable(ContextWrappingVariable):
    """
    This class represents Python contextlib.nullcontext.
    It's used as a placeholder for other context managers that Dynamo doesn't
    support yet, e.g., torch.autograd.profiler.record_function.
""" def __init__(self, target_values=None, **kwargs) -> None: super().__init__(target_values=target_values, **kwargs) def enter(self, tx): return variables.ConstantVariable.create(None) def exit(self, tx: "InstructionTranslator", *args): return variables.ConstantVariable.create(None) def module_name(self): return "contextlib" def fn_name(self): return "nullcontext" class StreamContextVariable(ContextWrappingVariable): @staticmethod def create(tx: "InstructionTranslator", target_value, **kwargs): from .builder import wrap_fx_proxy_cls current_stream_method = get_interface_for_device( target_value.device ).current_stream current_stream = wrap_fx_proxy_cls( StreamVariable, tx, tx.output.create_proxy( "call_function", current_stream_method, (None,), {}, ), ) return StreamContextVariable( target_values=[target_value], initial_values=[current_stream], device=target_value.device, **kwargs, ) def __init__(self, target_values, device, initial_values=None, **kwargs) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) self.device = device self.set_stream = get_interface_for_device(self.device).set_stream self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id def enter(self, tx): # stream generated inside the traced function if self.target_values[0].as_proxy() is not None: tx.output.create_proxy( "call_function", self.set_stream, (self.target_values[0].as_proxy(),), {}, ) # stream passed from outside the traced function else: stream = self.target_values[0].value tx.output.create_proxy( "call_function", self.set_stream_id, (stream.stream_id, stream.device_index, stream.device_type), {}, ) self.set_stream(self.target_values[0].value) self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value)) def exit(self, tx: "InstructionTranslator", *args): tx.output.create_proxy( "call_function", self.set_stream, (self.initial_values[0].as_proxy(),), {}, ) self.state.cleanup_assert() class PreserveVersionContextVariable(ContextWrappingVariable): """ Wraps torch.autograd._unsafe_preserve_version_counter """ @staticmethod def constructor(tx): return variables.LambdaVariable( lambda tensor: PreserveVersionContextVariable( tensor, tensor.var_getattr(tx, "_version"), ) ) def __init__(self, tensor, prev_version, **kwargs) -> None: kwargs.setdefault("target_values", None) super().__init__(**kwargs) self.tensor = tensor self.prev_version = prev_version def enter(self, tx): pass def exit(self, tx: "InstructionTranslator", *args): from ..tensor_version_op import _unsafe_set_version_counter return variables.TorchInGraphFunctionVariable( _unsafe_set_version_counter ).call_function(tx, [self.tensor, self.prev_version], {}) def reconstruct(self, codegen): unimplemented( "torch.autograd._unsafe_preserve_version_counter with graph break" ) class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable): _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE) @staticmethod def create(tx: "InstructionTranslator", param_group_var, target_value, **kwargs): var = FSDPParamGroupUseTrainingStateVariable( param_group_var=param_group_var, target_values=[target_value], initial_values=[param_group_var.value._training_state], **kwargs, ) return var def __init__( self, param_group_var, target_values, initial_values=None, **kwargs ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) self.param_group_var = param_group_var install_guard(self._guards_singleton) def enter(self, tx): 
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable.create(None)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        self._call_func(tx, self.initial_values)  # undo eager initialization
        return super().call_function(tx, args, kwargs)

    def _call_func(self, tx: "InstructionTranslator", values):
        assert len(values) == 1
        value = values[0]
        if self.param_group_var.value._training_state != value:
            self.param_group_var.call_method(
                tx,
                "__setattr__",
                (
                    variables.ConstantVariable.create("_training_state"),
                    variables.EnumVariable(value),
                ),
                {},
            )
            self.param_group_var.value._training_state = value

    def module_name(self):
        return "torch.distributed._composable.fsdp._fsdp_param_group.FSDPParamGroup"

    def fn_name(self):
        return "use_training_state"


class StreamVariable(VariableTracker):
    def __init__(self, proxy, value, device, **kwargs) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        assert (
            value.device.type == device.type
        ), "stream value is not equal to the passed device"
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        self.device = device

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert hasattr(self.value, name), f"no stream method found named {name}"
        assert name in [
            "wait_stream",
            "synchronize",
            "query",
            "record_event",
            "wait_event",
        ], f"unsupported stream method {name}"

        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return variables.ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=variables.ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        else:
            unimplemented(f"{self.device} stream method {name} unsupported")

    def as_proxy(self):
        return self.proxy

    def reconstruct(self, codegen):
        # If we got here, this stream is fully subsumed by the graph - this means it is
        # not an input or global.
        assert not self.source
        # Since we just proved that - for other such structures, like lists and dicts,
        # reconstruction is fine and sound according to dynamo principles of treating
        # collectives. However, streams are special in that we want to preserve the
        # identity of the stream as the same as in the graph. Normally, we would do
        # this via codegen for the proxy mapping to an output - we cannot do this yet,
        # as we do not yet have a plan for how we want to handle the case where the
        # stream is used as an input or an output. Pending design, to unblock current
        # work, we lift the stream into a global and then codegen bytecode to load it
        # from there.
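        # Rough sketch of the effect (assuming install_global_by_id mangles the name
        # from the prefix and the object id): the stream object is stashed in the
        # compiled frame's globals, e.g. globals()["_stream_cuda_..."] = <Stream>,
        # and the emitted bytecode is a LOAD_GLOBAL of that name, so reconstruction
        # yields the identical stream object used while tracing.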
prefix = f"_stream_{self.device}" name = codegen.tx.output.install_global_by_id(prefix, self.value) codegen.append_output(codegen.create_load_global(name, add=True)) class EventVariable(VariableTracker): def __init__(self, proxy, value, **kwargs) -> None: if proxy is not None and "example_value" in proxy.node.meta: assert proxy.node.meta["example_value"] == value super().__init__(**kwargs) self.proxy = proxy self.value = value def call_method( self, tx, name, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": from ..utils import proxy_args_kwargs from .builder import wrap_fx_proxy_cls if name in ("wait", "record", "synchronize"): tx.output.create_proxy( "call_method", name, *proxy_args_kwargs([self] + args, kwargs) ) return variables.ConstantVariable(None) elif name == "query": return wrap_fx_proxy_cls( target_cls=variables.ConstantVariable, tx=tx, proxy=tx.output.create_proxy( "call_method", name, *proxy_args_kwargs([self] + args, kwargs) ), ) else: unimplemented(f"event method {name} unsupported") def as_proxy(self): return self.proxy def reconstruct(self, codegen): # If we got here, this event is fully subsumed by the graph - this means it is # not an input or global assert not self.source # Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there. prefix = "_event" name = codegen.tx.output.install_global_by_id(prefix, self.value) codegen.append_output(codegen.create_load_global(name, add=True)) class WithExitFunctionVariable(VariableTracker): _nonvar_fields = { "target", *VariableTracker._nonvar_fields, } def __init__( self, ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable], target, **kwargs, ) -> None: super().__init__(**kwargs) assert isinstance( ctx, (ContextWrappingVariable, GenericContextWrappingVariable) ) self.ctx = ctx self.target = target def call_function( self, tx: "InstructionTranslator", args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": assert not kwargs return self.ctx.exit(tx, *args) def reconstruct(self, codegen): # Note here we reconstruct the context manager rather than the # exit function. The handler generated by BlockStackEntry # will re-enter the context in the resume function. self.ctx.reconstruct_type(codegen) if codegen.tx.output.partial_convert: if sys.version_info >= (3, 11): codegen.append_output(create_instruction("PUSH_NULL")) if sys.version_info < (3, 13): codegen.append_output(create_instruction("SWAP", arg=2)) codegen.extend_output( [codegen.create_load_const(val) for val in self.ctx.target_values] ) codegen.extend_output( create_call_function(len(self.ctx.target_values), False) ) codegen.append_output(create_setup_with(self.target)) codegen.append_output(create_instruction("POP_TOP"))