pytorch/torch/_dynamo/variables/optimizer.py
Michael Lazos c46af25bb3 Initialize optimizer in dynamo to avoid graph break and tracing slowness (#102640)
On calls to `_init_group`, rather than tracing through it, extract the python values from the argument VariableTrackers and call the initialization directly. This avoids tracing the function, which is very slow when there are many parameters, and also avoids graph breaking on it. This is sound here because in the eager case the state is only initialized once. Guards on the state and params are generated explicitly rather than via tracing the initialization.
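
For reference, a minimal sketch of the kind of user code that reaches this path (the model and optimizer here are illustrative, not taken from this PR): compiling the optimizer step makes dynamo trace `opt.step()`, whose call to `_init_group` is intercepted as described above.

import torch

model = torch.nn.Linear(1024, 1024)
opt = torch.optim.Adam(model.parameters())

model(torch.randn(4, 1024)).sum().backward()

# Compiling the step makes dynamo trace Adam.step(); its _init_group call is
# run eagerly (with guards installed) rather than traced or graph-broken.
@torch.compile
def step():
    opt.step()

step()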

Caveats:
`_init_group` also gathers various state tensors into lists, by mutating its list arguments, to pass to the functional optimizer implementation. These state tensors live on the optimizer itself, but we don't know exactly how the gathering is done or which tensors correspond to which attributes of the optimizer module (each optimizer has different state). To handle this, we keep weak pointers to all of the tensors collected into the lists in globals (similar to how parameter keys are stored for dictionaries). These pointers are guaranteed to stay alive as long as the optimizer object is alive and its internal state is not interfered with, and they are guarded with weakref guards.
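
To make the list-mutation caveat concrete, here is a simplified, hypothetical sketch of the `_init_group` pattern being described (not the actual torch.optim code): the eager implementation lazily creates per-parameter state and appends it into caller-provided lists, which is exactly the mutation that `update_list_args` below has to mirror onto the traced list arguments.

import torch

class SimplifiedOptimizer:
    """Illustrative stand-in for the _init_group pattern in torch.optim optimizers."""

    def __init__(self, params, lr=1e-3):
        self.param_groups = [{"params": list(params), "lr": lr}]
        self.state = {}

    def _init_group(self, group, params_with_grad, grads, exp_avgs, state_steps):
        # Mutates the caller-provided lists in place and lazily initializes
        # per-parameter state the first time a parameter is seen.
        for p in group["params"]:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            grads.append(p.grad)
            state = self.state.setdefault(p, {})
            if not state:
                state["exp_avg"] = torch.zeros_like(p)
                state["step"] = torch.tensor(0.0)
            exp_avgs.append(state["exp_avg"])
            state_steps.append(state["step"])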

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102640
Approved by: https://github.com/jansel
2023-06-03 15:49:51 +00:00


from typing import Dict, List

import torch

from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable


class ArgMappingException(Exception):
    pass


class OptimizerVariable(UserDefinedObjectVariable):
    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
        if name == "_init_group":
            try:
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                self.value._init_group(*py_args, **py_kwargs)
                self.install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
                return ConstantVariable(None)
            except ArgMappingException:
                # trace normally if we can't map args
                pass

        return super().call_method(tx, name, args, kwargs)

    def map_grads_to_sources(self):
        """Map the optimizer's grads to their sources"""
        self.grad_to_source = {}
        for g_ind, group in enumerate(self.value.param_groups):
            group_source = GetItemSource(AttrSource(self.source, "param_groups"), g_ind)
            for p_ind, p in enumerate(group["params"]):
                if p.grad is not None:
                    self.grad_to_source[p.grad] = AttrSource(
                        GetItemSource(GetItemSource(group_source, "params"), p_ind),
                        "grad",
                    )

    def var_getattr(self, tx, name):
        if name == "_init_group":
            return GetAttrVariable(self, name)

        return super().var_getattr(tx, name)

    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args"""

        def map_arg(arg):
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                return self.value.param_groups[arg.source.index]

            raise ArgMappingException()

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}
        return new_args, new_kwargs

    def install_guards(self, tx):
        from .builder import VariableBuilder

        state_dict_var = VariableBuilder(tx, AttrSource(self.source, "state"))(
            self.value.state
        )
        tx.output.guards.update(state_dict_var.guards)
        group_guards = VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
            self.value.param_groups
        )
        tx.output.guards.update(group_guards.guards)

    def wrap_tensor(self, tx, tensor_value):
        """Wrap state tensor in a TensorVariable"""
        from .builder import VariableBuilder

        # don't add weakref guards for grads, they will possibly change on
        # each iteration
        if tensor_value in self.grad_to_source:
            return VariableBuilder(tx, self.grad_to_source[tensor_value])(tensor_value)
        else:
            tx.store_dict_key(global_key_name(tensor_value), tensor_value)
            return VariableBuilder(
                tx, GlobalWeakRefSource(global_key_name(tensor_value))
            )(tensor_value)

    def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
        """Update the args and kwargs to the traced optimizer call"""
        self.map_grads_to_sources()
        for arg, py_arg in zip(args, py_args):
            if isinstance(arg, ListVariable) and all(
                isinstance(t, torch.Tensor) for t in py_arg
            ):
                tensor_vars = ListVariable(
                    [self.wrap_tensor(tx, t) for t in py_arg],
                    mutable_local=MutableLocal(),
                )
                arg.call_method(tx, "extend", (tensor_vars,), {})
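
For illustration, a minimal sketch of the weak-reference bookkeeping that `wrap_tensor` relies on, using hypothetical names (the real mechanism is `tx.store_dict_key` plus `GlobalWeakRefSource` / `global_key_name` above): gathered state tensors are stashed in a global table as weak references keyed per tensor, and a guard can later check that the entry still resolves to the same live tensor.

import weakref

import torch

# Hypothetical stand-ins for dynamo's global weakref storage and guard check.
_global_weakrefs = {}

def _key_for(tensor):
    return f"__optimizer_state_{id(tensor)}"

def store_weakref(tensor):
    _global_weakrefs[_key_for(tensor)] = weakref.ref(tensor)

def weakref_guard_holds(tensor):
    ref = _global_weakrefs.get(_key_for(tensor))
    return ref is not None and ref() is tensor

exp_avg = torch.zeros(10)
store_weakref(exp_avg)
assert weakref_guard_holds(exp_avg)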