pytorch/torch/_dynamo/variables/optimizer.py
Michael Lazos fbeca60b1f Remove replace_all and make VTs mutable (#113725)
1. Removes calls to `replace_all` and `clone` and makes VariableTrackers (VTs) mutable.
2. Properly handles TupleIterator mutation. Previously, TupleIterator variables were only reconstructed correctly if they were advanced at least once in a frame: each call to `next` constructed a new iterator without using the builder, losing the source information, which guaranteed that codegen would reconstruct the variable from scratch. Now that VTs are mutated in place, the source is never lost, so we track the mutation and handle it by replaying the calls to `next` at the end of the modified bytecode (see the iterator sketch after this list).
3. Added a test for `iadd` side effects; this was missing from our unit test coverage.
4. Fixed two incorrect sources: DelayGraphBreakVariable and UserMethodVariable both relied on setting the source to `AttrSource(parent, name)` at the call site of `var_getattr`.
5. Fixed a bug in in-place add for lists: it would set the resulting VariableTracker's source to `None`, which sent codegen down a different reconstruct path. This is now handled explicitly by reconstructing vars when `allow_cache=False`, so that during side-effect replay the mutated var is correctly updated (see the iadd sketch after this list).
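
A minimal sketch of the replay idea in point 2. The names here are hypothetical stand-ins, not Dynamo's real classes; the actual tracking goes through the side-effect infrastructure changed in this PR:

    # Track how many times the traced iterator was advanced, then replay
    # the same number of `next` calls on the real iterator at the end of
    # the modified bytecode.
    class ToyTupleIteratorVariable:
        def __init__(self, items):
            self.items = list(items)
            self.index = 0  # number of observed `next` calls

        def next_variable(self):
            value = self.items[self.index]
            self.index += 1
            return value

        def replay_next_calls(self, real_iterator):
            # emitted after the frame so the real iterator ends up in the
            # same state as the traced one
            for _ in range(self.index):
                next(real_iterator)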
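
And a plain-Python illustration (no Dynamo involved) of why points 3 and 5 need mutation replay: `list.__iadd__` mutates the original object in place, so the traced variable must keep its source and the mutation must be replayed on the real list:

    def f(xs):
        xs += [3]  # list.__iadd__ mutates xs in place
        return xs

    original = [1, 2]
    f(original)
    assert original == [1, 2, 3]  # caller-visible side effect that
                                  # codegen must replay on the real object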

In subsequent PRs:
* Refactor side-effect tracking to be significantly simpler (I think we only need an `is_modified` flag)
* Refactor the `next_variables` iterator to match the signature of `next`
* Remove all references to `options` in the code
* Refactor VTs representing mutable collections to implement their own mutation-update handling
* Remove `clone`, or make it specific to lists for creating slices
* Add mutation tracking/replay for sets
* Add mutation tracking/replay for iter.py
* Remove setting the source in builder (it's set at the top level after a var is returned)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113725
Approved by: https://github.com/jansel
2023-12-10 09:31:21 +00:00

import weakref
from typing import Dict, List

import torch

from ..decorators import mark_static_address
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name
from .base import VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable

class ArgMappingException(Exception):
    pass


class GuardInstallException(Exception):
    pass

class OptimizerVariable(UserDefinedObjectVariable):
    def __init__(
        self,
        value,
        grad_to_source=None,
        static_tensor_names=None,
        tensor_to_source=None,
        **kwargs,
    ):
        super().__init__(value, **kwargs)

        # set capturable on all param groups and mark params as
        # static addresses (used by cudagraphs) before tracing
        for group in self.value.param_groups:
            if "capturable" in group:
                group["capturable"] = True

            for p in group["params"]:
                mark_static_address(p, guard=False)

        self.grad_to_source = grad_to_source or {}
        self.tensor_to_source = tensor_to_source or {}
        self.static_tensor_names = static_tensor_names or set()
    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
        if name == "_init_group":
            try:
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                ret_val = self.value._init_group(*py_args, **py_kwargs)
                self.map_sources_and_install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
                # stash a weak_ptr to optimizer to invalidate code
                # if the optimizer object dies
                tx.store_global_weakref(self.get_global_name(), self.value)
                self.create_finalizer(tx)

                # This is currently safe only because the only actual `ret_val`s
                # returned by the `_init_group` of existing optimizers are
                # properties that are invariant to the input tensors (e.g. dtype,
                # layout). Changing these would trigger a recompilation and
                # hence never result in the wrong specialization of `ret_val`.
                return ConstantVariable.create(ret_val)
            except (ArgMappingException, GuardInstallException) as _:
                # trace normally if we can't map args or install guards correctly
                pass

        return super().call_method(tx, name, args, kwargs)
    def var_getattr(self, tx, name):
        if name == "_init_group":
            return GetAttrVariable(self, name)

        return super().var_getattr(tx, name)
    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args"""

        def map_arg(arg):
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                return self.value.param_groups[arg.source.index]

            # any other arg type means we cannot run _init_group eagerly
            raise ArgMappingException()

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

        return new_args, new_kwargs
    def map_sources_and_install_guards(self, tx):
        from .builder import VariableBuilder

        self.grad_to_source = {}
        self.tensor_to_source = {}

        for g_ind, group in enumerate(self.value.param_groups):
            group_source = GetItemSource(
                AttrSource(self.source, "param_groups"), g_ind
            )
            for p_ind, p in enumerate(group["params"]):
                param_source = GetItemSource(
                    GetItemSource(group_source, "params"), p_ind
                )
                self.tensor_to_source[p] = param_source
                if p.grad is not None:
                    self.grad_to_source[p.grad] = AttrSource(
                        param_source,
                        "grad",
                    )

        # state guards take a long time to generate
        # so we manually generate them here
        state_source = AttrSource(self.source, "state")
        install_guard(state_source.make_guard(GuardBuilder.DICT_KEYS))
        for p, value in self.value.state.items():
            tx.store_global_weakref(global_key_name(p), p)
            p_state_source = GetItemSource(state_source, self.tensor_to_source[p])
            install_guard(p_state_source.make_guard(GuardBuilder.DICT_KEYS))
            for k, v in value.items():
                if (
                    isinstance(v, torch.Tensor)
                    and v not in self.grad_to_source
                    and v not in self.tensor_to_source
                ):
                    self.tensor_to_source[v] = GetItemSource(p_state_source, k)
                elif v is None or isinstance(v, (bool, int, float, str)):
                    install_guard(
                        GetItemSource(p_state_source, k).make_guard(
                            GuardBuilder.CONSTANT_MATCH
                        )
                    )
                else:
                    raise GuardInstallException()

        # this next line has the side effect of installing guards
        VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
            self.value.param_groups
        ).recursive_realize()
    def wrap_tensor(self, tx, tensor_value):
        """Wrap state tensor in a TensorVariable"""
        from .builder import VariableBuilder

        # If we already have a source for a tensor, use it; if we have not
        # seen a tensor before, stash and use a global weak ref source,
        # since it must be an optimizer tensor that we have missed
        if tensor_value in self.tensor_to_source:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value, guard=False)
            builder = VariableBuilder(tx, self.tensor_to_source[tensor_value])
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))
        elif tensor_value in self.grad_to_source:
            builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
        else:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value, guard=False)

            tx.store_global_weakref(global_key_name(tensor_value), tensor_value)
            builder = VariableBuilder(
                tx, GlobalWeakRefSource(global_key_name(tensor_value))
            )
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))

        result = builder(tensor_value)
        return result
    def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
        """Update the args and kwargs to the traced optimizer call"""
        for arg, py_arg in zip(args, py_args):
            if isinstance(arg, ListVariable) and all(
                isinstance(t, torch.Tensor) for t in py_arg
            ):
                # record the mutation, then extend the traced list with the
                # tensors that `_init_group` populated eagerly
                tx.output.side_effects.mutation(arg)
                arg.items.extend([self.wrap_tensor(tx, t) for t in py_arg])
    def create_finalizer(self, tx):
        names_to_delete = self.static_tensor_names
        value = self.value
        tc = tx.output.tracing_context

        def init_finalizer(gm):
            def clear_static_tensor_refs():
                for name in names_to_delete:
                    gm._buffers.pop(name, None)
                    gm._parameters.pop(name, None)
                    if tc.params_flat:
                        tc.params_flat.clear()

            weakref.finalize(value, clear_static_tensor_refs)

        tx.output.add_graph_finalizer(init_finalizer)
    def get_global_name(self):
        return f"__optimizer_{id(self.value)}"
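
For context, a minimal sketch of the user-level path that exercises this class, assuming a PyTorch 2.x build where compiled optimizer steps route `_init_group` through `OptimizerVariable.call_method`:

    import torch

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.Adam(model.parameters())
    model(torch.randn(2, 4)).sum().backward()

    @torch.compile
    def step():
        opt.step()

    step()  # _init_group runs eagerly; the guards installed by
            # map_sources_and_install_guards keep the trace valid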