from typing import Dict, List

import torch

from ..guards import GuardBuilder
from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name

from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable


class ArgMappingException(Exception):
    pass


class GuardInstallException(Exception):
    pass


class OptimizerVariable(UserDefinedObjectVariable):
    def __init__(self, value, grad_to_source=None, tensor_to_source=None, **kwargs):
        super().__init__(value, **kwargs)

        # Force the capturable implementation where available so the
        # optimizer step can be traced.
        for group in self.value.param_groups:
            if "capturable" in group:
                group["capturable"] = True

        self.grad_to_source = grad_to_source or {}
        self.tensor_to_source = tensor_to_source or {}

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
        if name == "_init_group":
            try:
                # Run the real _init_group eagerly instead of tracing it, then
                # install guards and update the traced list args to reflect
                # its effects.
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                self.value._init_group(*py_args, **py_kwargs)
                self.map_sources_and_install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
                return ConstantVariable(None)
            except (ArgMappingException, GuardInstallException) as _:
                # trace normally if we can't map args or install guards correctly
                pass

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        if name == "_init_group":
            return GetAttrVariable(self, name)

        return super().var_getattr(tx, name)

    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args"""

        def map_arg(arg):
            # Constants and empty lists map directly; param_group dicts are
            # recovered from the optimizer via their source. Anything else
            # means we cannot run _init_group eagerly.
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                return self.value.param_groups[arg.source.index]

            raise ArgMappingException()

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

        return new_args, new_kwargs

    def map_sources_and_install_guards(self, tx):
        from .builder import VariableBuilder

        self.grad_to_source = {}
        self.tensor_to_source = {}

        # Map params and grads to the sources they are loaded from so that
        # state tensors can be wrapped with precise sources later.
        for g_ind, group in enumerate(self.value.param_groups):
            group_source = GetItemSource(
                AttrSource(self.source, "param_groups"), g_ind
            )
            for p_ind, p in enumerate(group["params"]):
                param_source = GetItemSource(
                    GetItemSource(group_source, "params"), p_ind
                )
                self.tensor_to_source[p] = param_source
                if p.grad is not None:
                    self.grad_to_source[p.grad] = AttrSource(
                        param_source,
                        "grad",
                    )

        # state guards take a long time to generate
        # so we manually generate them here
        guards = set()
        state_source = AttrSource(self.source, "state")
        guards.add(state_source.make_guard(GuardBuilder.DICT_KEYS))
        for p, value in self.value.state.items():
            tx.store_dict_key(global_key_name(p), p)
            p_state_source = GetItemSource(state_source, self.tensor_to_source[p])
            guards.add(p_state_source.make_guard(GuardBuilder.DICT_KEYS))
            for k, v in value.items():
                if (
                    isinstance(v, torch.Tensor)
                    and v not in self.grad_to_source
                    and v not in self.tensor_to_source
                ):
                    self.tensor_to_source[v] = GetItemSource(p_state_source, k)
                elif v is None or isinstance(v, (bool, int, float, str)):
                    guards.add(
                        GetItemSource(p_state_source, k).make_guard(
                            GuardBuilder.CONSTANT_MATCH
                        )
                    )
                else:
                    raise GuardInstallException()

        tx.output.guards.update(guards)

        group_guards = VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
            self.value.param_groups
        )
        tx.output.guards.update(group_guards.guards)

    def wrap_tensor(self, tx, tensor_value):
        """Wrap state tensor in a TensorVariable"""
        from .builder import VariableBuilder

        # If we already have a source for the tensor, use it. If we have not
        # seen the tensor before, stash it and use a global weak ref source,
        # since it must be an optimizer tensor that we have missed.
        if tensor_value in self.tensor_to_source:
            return VariableBuilder(tx, self.tensor_to_source[tensor_value])(
                tensor_value
            )
        elif tensor_value in self.grad_to_source:
            return VariableBuilder(tx, self.grad_to_source[tensor_value])(tensor_value)
        else:
            tx.store_dict_key(global_key_name(tensor_value), tensor_value)
            return VariableBuilder(
                tx, GlobalWeakRefSource(global_key_name(tensor_value))
            )(tensor_value)

    def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
        """Update the args and kwargs to the traced optimizer call"""
        for arg, py_arg in zip(args, py_args):
            if isinstance(arg, ListVariable) and all(
                isinstance(t, torch.Tensor) for t in py_arg
            ):
                # Replace the traced (still empty) tensor lists with the
                # tensors _init_group actually populated.
                tensor_vars = ListVariable(
                    [self.wrap_tensor(tx, t) for t in py_arg],
                    mutable_local=MutableLocal(),
                    recursively_contains={},
                )
                tx.replace_all(arg, tensor_vars)