pytorch/torch/_dynamo/side_effects.py

853 lines
35 KiB
Python

# mypy: allow-untyped-defs
import collections
import contextlib
import functools
import inspect
import warnings
import weakref
from collections.abc import MutableMapping
from types import CellType
from typing import Any, Optional
import torch.nn
from . import utils, variables
from .bytecode_transformation import (
bytecode_from_template,
create_call_function,
create_call_method,
create_instruction,
)
from .codegen import PyCodegen
from .exc import unimplemented
from .source import GlobalSource, LocalCellSource, LocalSource, Source
from .utils import dict_new, is_frozen_dataclass, nn_module_new, object_new, tuple_new
from .variables.base import (
AttributeMutation,
AttributeMutationExisting,
AttributeMutationNew,
is_side_effect_safe,
ValueMutationExisting,
ValueMutationNew,
VariableTracker,
)
from .variables.user_defined import FrozenDataClassVariable
def _manual_dict_setitem(dict_from, dict_to, mro_index):
# Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have
# to be careful because we don't want to trigger the user defined object
# setitem or clear. The mro_index is used to find the dict/OrderedDict from
# the class mro.
dict_class = type(dict_to).__mro__[mro_index]
dict_class.clear(dict_to)
for k, v in dict_from.items():
dict_class.__setitem__(dict_to, k, v)
class SideEffects:
"""
Track side effects (list mutation, setattr, etc) that need to be
applied after an FX graph is run.
"""
id_to_variable: dict[int, VariableTracker]
store_attr_mutations: dict[VariableTracker, dict[str, VariableTracker]]
keepalive: list[Any]
def __init__(
self,
output_graph,
id_to_variable=None,
store_attr_mutations=None,
keepalive=None,
save_for_backward=None,
tensor_hooks=None,
):
super().__init__()
self.output_graph_weakref = weakref.ref(output_graph)
self.id_to_variable = id_to_variable or {}
self.store_attr_mutations = store_attr_mutations or {}
self.keepalive = keepalive or []
self.save_for_backward = save_for_backward or []
self.tensor_hooks = tensor_hooks or {}
# Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph.
# Only applicable if this graph is created from Dynamo tracing in Compiled Autograd.
self.ca_final_callbacks_var = None
def __eq__(self, other: object) -> bool:
assert isinstance(other, SideEffects)
# NB: do NOT test keepalive
return (
self.id_to_variable == other.id_to_variable
and self.store_attr_mutations == other.store_attr_mutations
and self.save_for_backward == other.save_for_backward
and self.tensor_hooks == other.tensor_hooks
)
def diff(self, other: "SideEffects") -> Optional[str]:
if self.id_to_variable != other.id_to_variable:
sk_itv = self.id_to_variable.keys()
ok_itv = other.id_to_variable.keys()
if sk_itv != ok_itv:
return f"id_to_variable keys: {sk_itv} != {ok_itv}"
# Feel free to augment this with more fancy diffing logic
# if needed for debugging
return "id_to_variable: unknown diff"
elif self.store_attr_mutations != other.store_attr_mutations:
sk_sam = self.store_attr_mutations.keys()
ok_sam = other.store_attr_mutations.keys()
if sk_sam != ok_sam:
return f"store_attr_mutations keys: {sk_sam} != {ok_sam}"
return "store_attr_mutations: unknown diff"
elif self.save_for_backward != other.save_for_backward:
return "save_for_backward"
elif self.tensor_hooks != other.tensor_hooks:
return "tensor_hooks"
else:
return None
def clone(self):
"""Create a shallow copy"""
return self.__class__(
output_graph=self.output_graph_weakref(),
id_to_variable=dict(self.id_to_variable),
store_attr_mutations={
k: dict(v) for k, v in self.store_attr_mutations.items()
},
keepalive=list(self.keepalive),
save_for_backward=self.save_for_backward,
tensor_hooks=self.tensor_hooks,
)
def __contains__(self, item):
return id(item) in self.id_to_variable
def __getitem__(self, item):
return self.id_to_variable[id(item)]
def should_allow_side_effects_under_checkpoint(self):
output_graph = self.output_graph_weakref()
return (
output_graph
and output_graph.current_tx.output.current_tracer.under_activation_checkpoint
and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint
)
def check_allowed_side_effect(self, item):
from torch._dynamo.variables.misc import AutogradFunctionContextVariable
# People do things like self.dim = dim inside autograd.Function.
# These are benign.
if isinstance(item, AutogradFunctionContextVariable):
return True
if self.should_allow_side_effects_under_checkpoint():
return True
if not is_side_effect_safe(item.mutation_type):
unimplemented(
"HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
)
def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
assert self.is_attribute_mutation(item)
self.check_allowed_side_effect(item)
if item not in self.store_attr_mutations:
self.store_attr_mutations[item] = {}
self.store_attr_mutations[item][name] = value
def load_attr(self, item, name, deleted_ok=False, check=False):
if check:
assert self.is_attribute_mutation(item)
result = self.store_attr_mutations[item][name]
if not deleted_ok and isinstance(result, variables.DeletedVariable):
unimplemented("read deleted attribute")
return result
def store_cell(self, cellvar, value):
if cellvar.is_immutable():
unimplemented("Dynamo currently doesn't support writing to such cell")
assert isinstance(cellvar, variables.CellVariable)
assert isinstance(value, variables.VariableTracker)
self.store_attr(cellvar, "cell_contents", value)
def load_cell(self, cellvar):
assert isinstance(cellvar, variables.CellVariable)
if self.has_pending_mutation_of_attr(cellvar, "cell_contents"):
return self.load_attr(cellvar, "cell_contents", check=False)
if cellvar.pre_existing_contents:
return cellvar.pre_existing_contents
unimplemented("cannot read uninitialized cell")
def load_global(self, gvar: VariableTracker, name: str):
assert isinstance(gvar, variables.VariableTracker)
return self.load_attr(gvar, name)
def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker):
assert isinstance(gvar, variables.VariableTracker)
assert isinstance(value, variables.VariableTracker)
self.store_attr(gvar, name, value)
@staticmethod
def cls_supports_mutation_side_effects(cls):
return inspect.getattr_static(cls, "__getattribute__", None) in (
object.__getattribute__,
dict.__getattribute__,
int.__getattribute__,
str.__getattribute__,
)
def is_attribute_mutation(self, item):
return isinstance(item.mutation_type, AttributeMutation)
def has_pending_mutation(self, item):
return self.is_attribute_mutation(item) and bool(
self.store_attr_mutations.get(item)
)
def has_pending_mutation_of_attr(self, item, name):
return self.is_attribute_mutation(
item
) and name in self.store_attr_mutations.get(item, ())
def is_modified(self, item):
if item.is_immutable():
return False
if isinstance(item.mutation_type, (AttributeMutationNew, ValueMutationNew)):
return True
if self.is_attribute_mutation(item):
return item in self.store_attr_mutations
return item.mutation_type.is_modified
def _track_obj(
self,
item: Any,
variable: VariableTracker,
mutation_type_cls=ValueMutationExisting,
):
"""Start tracking a new variable for mutation"""
assert variable.source is not None
if id(item) in self.id_to_variable:
raise AssertionError(
f"{variable} is already tracked for mutation. This could be "
"because you are not using VariableBuilder to construct "
"the variable tracker. "
f"Source of new object: {variable.source}. "
f"Source of previously tracked object: {self.id_to_variable[id(item)].source}."
)
variable.mutation_type = mutation_type_cls()
self.id_to_variable[id(item)] = variable
self.keepalive.append(item)
return variable
track_mutable = _track_obj
def track_object_existing(
self,
item: Any,
variable: VariableTracker,
):
return self._track_obj(
item, variable, mutation_type_cls=AttributeMutationExisting
)
def track_object_new(
self,
cls_source: Source,
user_cls: Any,
variable_cls: Any,
options,
):
if user_cls is torch.autograd.function.FunctionCtx:
with warnings.catch_warnings(record=True):
obj = torch.autograd.Function()
elif issubclass(user_cls, torch.nn.Module):
obj = nn_module_new(user_cls)
elif issubclass(user_cls, (dict, collections.OrderedDict)):
obj = dict_new(user_cls)
elif issubclass(user_cls, tuple):
obj = tuple_new(user_cls)
else:
try:
obj = object_new(user_cls)
except TypeError:
# TODO(anijain2305/jansel) - Even though object.__new__ is same
# as user_cls.__new__, calling object.__new__(user_cls) fails
# with TypeError.
unimplemented(f"Unable to construct the object of type {user_cls}")
variable = variable_cls(
obj,
mutation_type=AttributeMutationNew(cls_source),
**options,
)
self.id_to_variable[id(obj)] = variable
self.keepalive.append(obj)
return variable
def track_object_new_from_user_defined_class(
self,
cls_variable: "variables.UserDefinedClassVariable",
):
cls_source = cls_variable.source
user_cls = cls_variable.value
# Find the variable class
variable_cls: type[
variables.UserDefinedObjectVariable
] = variables.UserDefinedObjectVariable
if issubclass(user_cls, torch.nn.Module):
variable_cls = variables.UnspecializedNNModuleVariable
elif issubclass(user_cls, (dict, collections.OrderedDict)):
variable_cls = variables.UserDefinedDictVariable
elif issubclass(user_cls, tuple):
variable_cls = variables.UserDefinedTupleVariable
elif issubclass(user_cls, MutableMapping):
variable_cls = variables.MutableMappingVariable
elif is_frozen_dataclass(user_cls):
variable_cls = FrozenDataClassVariable
else:
variable_cls = variables.UserDefinedObjectVariable
assert issubclass(variable_cls, variables.UserDefinedObjectVariable)
variable_cls = functools.partial(variable_cls, cls_source=cls_source)
return self.track_object_new(cls_source, user_cls, variable_cls, {})
def track_cell_new(
self,
):
obj = object()
variable = variables.CellVariable(
mutation_type=AttributeMutationNew(),
)
self.id_to_variable[id(obj)] = variable
self.keepalive.append(obj)
return variable
def track_cell_existing(
self, source: Optional[Source], cell: CellType, contents: VariableTracker
):
variable = variables.CellVariable(
# We don't support mutation to cell without source because we need
# source to properly codegen the mutations.
mutation_type=None if source is None else AttributeMutationExisting(),
pre_existing_contents=contents,
source=source,
)
self.id_to_variable[id(cell)] = variable
self.keepalive.append(cell)
return variable
def track_global_existing(self, source: Source, item: Any):
variable = variables.NewGlobalVariable(
mutation_type=AttributeMutationExisting(),
source=source,
)
self.id_to_variable[id(item)] = variable
self.keepalive.append(item)
return variable
def track_save_for_backward(self, ctx, args):
assert isinstance(ctx, variables.AutogradFunctionContextVariable)
self.save_for_backward.append((ctx, args))
def track_tensor_variables_from_runahead_side_effects(self, other):
# In higher order ops we want to keep track of tensors seen in the
# speculate_subgraph so that we don't lift them again as a new input in
# other speculate_subgraph or in the root tracer.
for other_item in other.keepalive:
other_id = id(other_item)
other_variable = other.id_to_variable[other_id]
if other_id not in self.id_to_variable and isinstance(
other_variable, variables.TensorVariable
):
self.track_object_existing(other_item, other_variable)
def prune_dead_object_new(self, tx):
# Avoid VT cycles from e.g., recursive function.
visited: set[VariableTracker] = set()
live_new_objects: set[VariableTracker] = set()
def visit(var: VariableTracker):
if var in visited:
return
visited.add(var)
# Object may have been mutated, store this mutation.
if isinstance(var.mutation_type, AttributeMutationNew):
live_new_objects.add(var)
# It's possible that we have mutated the value of this variable
# to be another one. The new value is in store_attr_mutations.
# Also recurse through the new value to detect alive AttributeMutationNew.
if var in self.store_attr_mutations:
VariableTracker.visit(
visit, self.store_attr_mutations[var] # noqa: F821
)
def is_live(var: VariableTracker):
if isinstance(var.mutation_type, AttributeMutationNew):
return var in live_new_objects
return True
pre_existing_vars = [
var
for var in self.id_to_variable.values()
if not isinstance(var.mutation_type, AttributeMutationNew)
]
# The only live side effects come from returns (tx.stack), any intermediates
# during a graph break (tx.symbolic_locals), and mutation on pre-existing variables.
# Recursively visit Variables and see if any of them have been mutated.
VariableTracker.visit(
visit,
# TODO track from all possible sources.
(
tx.stack,
tx.symbolic_locals,
pre_existing_vars,
tx.output.backward_state,
self.tensor_hooks,
),
)
# Manually release the self-referential function, which indirectly
# captures certain `VariableTracker` and affects parts of PT test/logic
# that are sensitive to when certain objects get released.
del visit
# NB: cell variable handling.is tricky.
# cell variables must stay alive if any NestedUserFunctionVariable
# are live. "visit"-ing the NestedUserFunctionVariable visits
# the .closures field, from which we will see if we need to keep
# any mutations to cell variables alive.
self.id_to_variable = {
k: v for k, v in self.id_to_variable.items() if is_live(v)
}
self.store_attr_mutations = {
k: v for k, v in self.store_attr_mutations.items() if is_live(k)
}
def mutation(self, var):
self.check_allowed_side_effect(var)
if isinstance(var.mutation_type, ValueMutationExisting):
var.mutation_type.is_modified = True
def _get_modified_vars(self):
return [var for var in self.id_to_variable.values() if self.is_modified(var)]
def get_new_function(self, var):
if isinstance(var, variables.UserDefinedDictVariable):
return "dict_new"
elif isinstance(var, variables.UserDefinedTupleVariable):
return "tuple_new"
return "object_new"
def codegen_save_tempvars(self, cg: PyCodegen):
# Make sure we codegen these modified VT to their source by default, so
# that mutation and aliasing are properly accounted for.
for var in self._get_modified_vars():
if isinstance(var.mutation_type, AttributeMutationNew) and isinstance(
var, variables.CellVariable
):
# Cells created in the root frame are created either by
# `MAKE_CELL` or by them being in `co_cellvars`, so we only emit
# `make_cell` for the non-root-frame cells here.
# TODO generalize this so we never need to call `make_cell`.
if var.local_name is None:
cg.add_push_null(
lambda: cg.load_import_from(utils.__name__, "make_cell")
)
cg.extend_output(create_call_function(0, False))
cg.add_cache(var)
var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined]
elif var.source is None:
var.source = LocalCellSource(var.local_name)
elif isinstance(var.mutation_type, AttributeMutationNew):
if isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented("AutogradFunctionContextVariable escaped")
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, self.get_new_function(var)
)
)
cg(var.mutation_type.cls_source)
if isinstance(var, variables.UserDefinedTupleVariable) and var.new_args:
cg(var.new_args)
cg.extend_output(create_call_function(2, False))
else:
cg.extend_output(create_call_function(1, False))
cg.add_cache(var)
var.source = LocalSource(cg.tempvars[var])
else:
# The remaning cases here are `AttributeMutationExisting` and
# `MutableSideEffects`, which have sources already.
assert var.source is not None
for ctx, args in self.save_for_backward:
cg(ctx.source)
cg.load_method("save_for_backward")
for arg in args:
cg(arg)
cg.extend_output(
[
*create_call_method(len(args)),
create_instruction("POP_TOP"),
]
)
def register_hook(self, tensor, hook, handle, name):
assert isinstance(tensor, variables.TensorVariable)
assert isinstance(hook, variables.VariableTracker)
assert (
isinstance(handle, variables.RemovableHandleVariable)
and handle.is_mutable()
)
assert hasattr(torch.Tensor, name)
idx = len(self.tensor_hooks.keys())
# duplicate index possible because of self.remove_hook()
while idx in self.tensor_hooks:
idx += 1
self.tensor_hooks[idx] = (tensor, hook, handle, name)
assert not handle.idx
handle.idx = idx
def remove_hook(self, idx):
del self.tensor_hooks[idx]
def codegen_hooks(self, cg):
for (
tensor,
hook,
handle,
name,
) in self.tensor_hooks.values():
# Note: [On tensor.register_hook]
#
# register_hook on a tensor, AKA backward hooks, have slightly nuanced differences in how they are implemented
# when it comes to hooks on objects with sources (inputs, params) vs objects without sources (intermediaries).
#
# For tensors with a source, we bypass direct inclusion of register_hook calls in the graph.
# Instead, these are tracked and stashed as a global variable, enabling their association with tensors in
# the residuals. During dynamo's frame creation, these hooks are invoked seamlessly on known reconstructible/fetch-able
# tensors. Because a source indicates knowledge of this object outside the torch compile region, and
# because we are running residuals firmly before .backward() can be run, it is sound to invoke
# `register_hook` on a known tensor.
#
# For tensors without a source, we support a limited subset of hooks. Global functions only, and
# compiled_autograd must be enabled or we will graph break.
#
# Handling the Handle: When a user retains the register_hook result in a handle, we intercept the
# STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed
# bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the
# stack intact.
#
# Dynamo Tensor Hooks Workflow:
# - Functions passed to register_hook are lifted globally.
# - For tensors with sources:
# - In the "side_effects" phase of codegen, we iterate over tensors with hooks to:
# - Generate the tensor.
# - Issue a register_hook call on the tensor, linking to the globally stored function.
# - Incorporate a handle if one was established in the eager phase.
# - For tensors without sources:
# - We don't generate any instructions for registering a hook.
# - Handles from intermediary hooks are NYI.
# - We produce a call function that utilizes the trace_wrapped higher order op, closing over it.
# - We then manually insert the call function above into the graph.
# - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST.
assert tensor.source, "Hooks on non input tensors NYI - should not get here"
def gen_fn():
cg(tensor)
cg.extend_output([cg.create_load_attr(name)])
cg.add_push_null(gen_fn)
cg(hook)
cg.extend_output(create_call_function(1, False))
# Adding the handle to the cache means RemovableHandleVariable().reconstruct() will
# be associated with the return value of register_hook(). This consumes the top of stack.
cg.add_cache(handle)
def get_ca_final_callbacks_var(self):
from .variables.base import ValueMutationNew
if self.ca_final_callbacks_var is None:
self.ca_final_callbacks_var = variables.ListVariable(
[], mutation_type=ValueMutationNew()
)
return self.ca_final_callbacks_var
def codegen_update_mutated(self, cg: PyCodegen):
suffixes = []
for var in self._get_modified_vars():
if isinstance(var, variables.ListVariable):
# old[:] = new
cg(var, allow_cache=False) # Don't codegen via source
cg(var.source) # type: ignore[attr-defined]
cg.extend_output(
[
cg.create_load_const(None),
cg.create_load_const(None),
create_instruction("BUILD_SLICE", arg=2),
]
)
suffixes.append([create_instruction("STORE_SUBSCR")])
elif isinstance(var, variables.lists.DequeVariable):
# For limited maxlen, the order of operations matter for side
# effect, but we currently don't track the order, so no support.
if not (
isinstance(var.maxlen, variables.ConstantVariable)
and var.maxlen.value is None
):
unimplemented("side effect on existing deque with limited maxlen")
# old.extend(new), this runs last
cg(var.source)
cg.load_method("extend")
cg(var, allow_cache=False) # Don't codegen via source
suffixes.append(
[
*create_call_method(1),
create_instruction("POP_TOP"),
]
)
# old.clear(), this runs first
cg(var.source)
cg.load_method("clear")
suffixes.append(
[
*create_call_method(0),
create_instruction("POP_TOP"),
]
)
elif isinstance(var, variables.ConstDictVariable):
# Reconstruct works as follow:
# (1) Skip codegen if there are no new items
# (2) codegen(...) each pair of key/value
# (3) create a new dictionary with the pairs of key/values above
# (4) clear the original dictionary
# + only if a key was removed from the input dict
# (5) update the original dictionary with the dict created in (2)
if var.has_new_items():
cg(var.source) # type: ignore[attr-defined]
cg.load_method("update")
cg(var, allow_cache=False) # Don't codegen via source
if var.should_reconstruct_all:
cg(var.source) # type: ignore[attr-defined]
cg.load_method("clear")
suffixes.append(
[
*create_call_method(1), # update
create_instruction("POP_TOP"),
]
)
if var.should_reconstruct_all:
# clear will appear before "update" as the suffixes are
# applied in reverse order.
suffixes.append(
[
*create_call_method(0), # clear
create_instruction("POP_TOP"),
]
)
elif isinstance(
var, variables.torch_function.TorchFunctionModeStackVariable
):
# Needed in the finally block for stack restoration
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "get_torch_function_mode_stack"
)
)
cg.call_function(0, False)
name = variables.torch_function.get_prev_stack_var_name()
cg.code_options["co_varnames"] += (name,)
cg.append_output(create_instruction("STORE_FAST", argval=name))
cg.add_push_null(
lambda: cg.load_import_from(
utils.__name__, "set_torch_function_mode_stack"
)
)
cg.foreach(var.symbolic_stack)
cg.append_output(
create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
)
cg.call_function(1, False)
cg.append_output(create_instruction("POP_TOP"))
elif isinstance(var, variables.CellVariable) and var.local_name is not None:
# Emit more readable and performant bytecode.
# TODO generalize this for cells created during inlining.
if var in self.store_attr_mutations:
contents_var = self.load_cell(var)
cg(contents_var)
suffixes.append([cg.create_store_deref(var.local_name)])
elif self.is_attribute_mutation(var):
if isinstance(var, variables.UserDefinedDictVariable):
# Do dict related update manually here. The store_attr
# mutations will be applied later.
varname_map = {}
for name in _manual_dict_setitem.__code__.co_varnames:
varname_map[name] = cg.tx.output.new_var()
try:
mro_index = type(var.value).__mro__.index(
collections.OrderedDict
)
except ValueError:
mro_index = type(var.value).__mro__.index(dict)
cg.extend_output(
[
create_instruction("LOAD_CONST", argval=mro_index),
create_instruction(
"STORE_FAST", argval=varname_map["mro_index"]
),
]
)
cg(var.source) # type: ignore[attr-defined]
cg.extend_output(
[
create_instruction(
"STORE_FAST", argval=varname_map["dict_to"]
)
]
)
cg(var._dict_vt, allow_cache=False) # Don't codegen via source
cg.extend_output(
[
create_instruction(
"STORE_FAST", argval=varname_map["dict_from"]
)
]
)
dict_update_insts = bytecode_from_template(
_manual_dict_setitem, varname_map=varname_map
)
suffixes.append(
[
*dict_update_insts,
create_instruction("POP_TOP"),
]
)
# Applying mutations involves two steps: 1) Push all
# reconstructed objects onto the stack. 2) Call STORE_ATTR to
# apply the mutations.
#
# Dynamo must ensure that mutations are applied in the same
# order as in the original program. Therefore, two reverse
# operations occur below.
#
# The first reverse operation concerns `suffixes`. We apply
# suffixes in reverse order due to the way Python handles the
# stack. In Step 1, we push all reconstructed objects onto the
# stack, but the item at the top of the stack refers to the last
# attribute in the mutation order. If not fixed, this will apply
# the mutations of attributes in the reverse order. To account
# for this reversal, we iterate through the mutable attributes
# in reverse order.
for name, value in reversed(
self.store_attr_mutations.get(var, {}).items()
):
if isinstance(var, variables.NewGlobalVariable):
cg.tx.output.update_co_names(name)
cg(value)
assert isinstance(var.source, GlobalSource) # type: ignore[attr-defined]
suffixes.append(
[create_instruction("STORE_GLOBAL", argval=name)]
)
elif isinstance(value, variables.DeletedVariable):
if isinstance(
var.mutation_type, AttributeMutationExisting
) and hasattr(getattr(var, "value", None), name):
cg.tx.output.update_co_names(name)
cg(var.source)
suffixes.append(
[create_instruction("DELETE_ATTR", argval=name)]
)
elif (
isinstance(var, variables.UserDefinedObjectVariable)
and var.needs_slow_setattr()
):
# __setattr__ is defined on this object, so call object.__setattr__ directly
cg.load_import_from("builtins", "object")
cg.load_method("__setattr__")
cg(var.source) # type: ignore[attr-defined]
cg(variables.ConstantVariable(name))
cg(value)
suffixes.append(
[*create_call_method(3), create_instruction("POP_TOP")]
)
else:
cg.tx.output.update_co_names(name)
cg(value)
cg(var.source)
suffixes.append([create_instruction("STORE_ATTR", argval=name)])
elif isinstance(var, variables.ListIteratorVariable):
for _ in range(var.index):
cg.add_push_null(
lambda: cg.load_import_from(utils.__name__, "iter_next")
)
cg(var.source) # type: ignore[attr-defined]
cg.call_function(1, False)
cg.pop_top()
elif isinstance(var, variables.RandomVariable):
# set correct random seed state
def gen_fn():
cg(var.source) # type: ignore[attr-defined]
cg.load_attr("setstate")
cg.add_push_null(gen_fn)
cg(var.wrap_state(var.random.getstate()))
suffixes.append(
[
*create_call_function(1, False), # setstate
create_instruction("POP_TOP"),
]
)
else:
raise AssertionError(type(var))
# do all the actual mutations at the very end to handle dependencies
for suffix in reversed(suffixes):
cg.extend_output(suffix)
def is_empty(self):
return not (
any(map(self.is_modified, self.id_to_variable.values()))
or self.tensor_hooks
or self.save_for_backward
or self.tensor_hooks
)
def clear(self):
self.keepalive.clear()
self.id_to_variable.clear()
@contextlib.contextmanager
def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): # type: ignore[name-defined] # noqa: F821
assert tx.output.current_tracer.under_activation_checkpoint
orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint
try:
tx.output.current_tracer.allow_side_effects_under_checkpoint = True
yield
finally:
tx.output.current_tracer.allow_side_effects_under_checkpoint = orig_val