On calls to `_init_group`, rather than tracing through it, extract python values from the arguments and call the initialization directly. This avoids tracing the function, which is very slow with large numbers of parameters, and also avoids graph breaking on it. This is sound because in the eager case the state is only initialized once. Guards on the state and params are generated explicitly rather than via tracing the initialization.

Caveats: `_init_group` also gathers various state tensors into lists (by mutating its list arguments) to pass to the functional optimizer implementation. These state tensors exist on the optimizer itself, but we don't know exactly how the gathering is done or which tensors correspond to which attributes of the optimizer module (each optimizer has different state). To rectify this, we keep weak_ptrs to all of the tensors collected in the lists in globals (similar to how parameter keys are stored for dictionaries). These pointers are guaranteed to stay alive as long as the optimizer object is alive, provided the internal state is not interfered with, and they are guarded with weakref guards.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102640
Approved by: https://github.com/jansel
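For context, a minimal sketch (not part of this PR) of the kind of compiled optimizer step this fast path targets; the model, optimizer, and sizes below are illustrative assumptions:

import torch

model = torch.nn.Linear(1024, 1024)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

@torch.compile
def opt_step():
    opt.step()

# populate .grad so the optimizer has state to initialize
model(torch.randn(4, 1024)).sum().backward()

# On the first call, dynamo intercepts `_init_group`: it runs the real
# initialization eagerly with python values, installs guards on `opt.state`
# and `opt.param_groups`, and rewraps the gathered state tensors, instead of
# tracing (or graph breaking on) the initialization itself.
opt_step()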
121 lines
4.4 KiB
Python
from typing import Dict, List

import torch

from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name

from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable


class ArgMappingException(Exception):
    pass


class OptimizerVariable(UserDefinedObjectVariable):
    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
        if name == "_init_group":
            try:
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                # run the real initialization eagerly instead of tracing it
                self.value._init_group(*py_args, **py_kwargs)
                self.install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
                return ConstantVariable(None)
            except ArgMappingException:
                # trace normally if we can't map args
                pass

        return super().call_method(tx, name, args, kwargs)

    def map_grads_to_sources(self):
        """Map the optimizer's grads to their sources"""
        self.grad_to_source = {}
        for g_ind, group in enumerate(self.value.param_groups):
            group_source = GetItemSource(AttrSource(self.source, "param_groups"), g_ind)
            for p_ind, p in enumerate(group["params"]):
                if p.grad is not None:
                    self.grad_to_source[p.grad] = AttrSource(
                        GetItemSource(GetItemSource(group_source, "params"), p_ind),
                        "grad",
                    )

    def var_getattr(self, tx, name):
        if name == "_init_group":
            return GetAttrVariable(self, name)

        return super().var_getattr(tx, name)

    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args"""

        def map_arg(arg):
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                return self.value.param_groups[arg.source.index]

            raise ArgMappingException()

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

        return new_args, new_kwargs

    def install_guards(self, tx):
        from .builder import VariableBuilder

        state_dict_var = VariableBuilder(tx, AttrSource(self.source, "state"))(
            self.value.state
        )
        tx.output.guards.update(state_dict_var.guards)

        group_guards = VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
            self.value.param_groups
        )
        tx.output.guards.update(group_guards.guards)

    def wrap_tensor(self, tx, tensor_value):
        """Wrap state tensor in a TensorVariable"""
        from .builder import VariableBuilder

        # don't add weakref guards for grads, they will possibly change on
        # each iteration
        if tensor_value in self.grad_to_source:
            return VariableBuilder(tx, self.grad_to_source[tensor_value])(tensor_value)
        else:
            tx.store_dict_key(global_key_name(tensor_value), tensor_value)
            return VariableBuilder(
                tx, GlobalWeakRefSource(global_key_name(tensor_value))
            )(tensor_value)

    def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
        """Update the args and kwargs to the traced optimizer call"""
        self.map_grads_to_sources()
        for arg, py_arg in zip(args, py_args):
            if isinstance(arg, ListVariable) and all(
                isinstance(t, torch.Tensor) for t in py_arg
            ):
                tensor_vars = ListVariable(
                    [self.wrap_tensor(tx, t) for t in py_arg],
                    mutable_local=MutableLocal(),
                )
                arg.call_method(tx, "extend", (tensor_vars,), {})