pytorch/torch/_dynamo/variables/optimizer.py
Animesh Jain eb9c127341 [dynamo][optimizers] Install ID_GUARDED tensors into the Fx graph (#147824)
Earlier, with the inline flag, we were lifting id-guarded tensors to inputs of the FX graph. But this offers no benefit: the main idea behind lifting parameters as inputs was to reuse the compilation unit across many instances of the nn.Module. However, if we are guarding on the `id`, we are explicitly specializing the compiled artifact to the parameter.

This PR installs the parameters back into the graph. The benefit is the removal of all pre-graph bytecode that extracts the id-guarded tensors from locals/globals. This increases the speedup from 1.67x to 1.75x for an internal model that has a large number of optimizer parameters.
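
For illustration, a minimal sketch of the kind of compiled training step this affects (the model, optimizer, and shapes below are placeholders, not from the PR); after this change the id-guarded parameter and optimizer-state tensors are installed as attributes on the FX graph instead of being fetched from locals/globals by pre-graph bytecode:

    import torch

    # toy model and optimizer; names and sizes are arbitrary placeholders
    model = torch.nn.Linear(16, 16)
    opt = torch.optim.Adam(model.parameters())

    @torch.compile
    def train_step(inp):
        loss = model(inp).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()

    train_step(torch.randn(8, 16))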

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147824
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@meta.com>
2025-02-28 03:22:11 +00:00


# mypy: ignore-errors
"""
This module implements variable tracking for PyTorch optimizers during Dynamo tracing.
The OptimizerVariable class provides specialized handling for optimizer instances by:
- Optimizing the tracing of expensive optimizer initialization
- Managing optimizer state and parameter group tracking
- Handling tensor sources and guards for optimizer state tensors
- Supporting CUDA graph execution through static tensor address management
- Providing special handling for parameter gradients and optimizer state tensors
Key features include:
- Efficient initialization tracing via _init_group optimization
- Automatic marking of optimizer state tensors as static for CUDA graphs
- Proper source tracking for parameter groups, gradients, and state tensors
- Guard installation for optimizer state structure
- Support for both CPU and GPU tensor handling
- Cleanup of static tensor references via finalizers
The module integrates with Dynamo's broader tracing system while providing
optimizer-specific optimizations and safety guarantees.
"""
import logging
import weakref
from typing import TYPE_CHECKING
import torch
from torch._logging import getArtifactLogger
from torch.utils._pytree import tree_map_only
from ..guards import GuardBuilder, install_guard
from ..source import (
AttrSource,
ConstDictKeySource,
DictGetItemSource,
GetItemSource,
GlobalWeakRefSource,
GradSource,
)
from ..utils import GLOBAL_KEY_PREFIX
from .base import VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
class ArgMappingException(Exception):
pass
class GuardInstallException(Exception):
pass
perf_hint_log = getArtifactLogger(__name__, "perf_hints")
def _is_static_for_cudagraphs(x):
from torch._inductor.cudagraph_trees import get_manager
if x.is_cuda:
manager = get_manager(x.device.index, False)
is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None
if manager:
return (
is_static_address
or manager.current_node._is_cuda_graph_recorded_tensor(x)
)
else:
return is_static_address
else:
# Don't print a warning for non-cuda tensors
return True
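# Note: this helper is consulted by map_sources_and_install_guards below to
# decide whether to emit the cudagraphs perf hint; grads that do not live at a
# static address will be copied during cudagraph execution.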
class OptimizerVariable(UserDefinedObjectVariable):
_nonvar_fields = {
"grad_to_source",
"tensor_to_source",
"static_tensor_names",
*UserDefinedObjectVariable._nonvar_fields,
}
def __init__(
self,
value,
grad_to_source=None,
static_tensor_names=None,
tensor_to_source=None,
**kwargs,
) -> None:
super().__init__(value, **kwargs)
self.grad_to_source = grad_to_source or {}
self.tensor_to_source = tensor_to_source or {}
self.static_tensor_names = static_tensor_names or set()
def call_method(
self,
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
"""This is an optimization to avoid tracing the very slow initialization of the optimizer"""
if name == "_init_group":
try:
self.graph_break_if_pending_mutation(tx)
self.move_step_if_cpu()
py_args, py_kwargs = self.get_python_args(*args, **kwargs)
ret_val = self.value._init_group(*py_args, **py_kwargs)
self.map_sources_and_install_guards(tx)
self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
# stash a weak reference to the optimizer to invalidate the code
# if the optimizer object dies
mangled_name = f"__optimizer_{id(self.value)}"
tx.store_global_weakref_by_id(mangled_name, self.value)
self.create_finalizer(tx)
# This is currently safe only because the only actual `ret_val`s returned
# by the `_init_group` of existing optimizers are properties that are invariant
# to the input tensors (e.g. dtype, layout). Changing these would trigger a
# recompilation and hence never result in the wrong specialization of `ret_val`.
return ConstantVariable.create(ret_val)
except (ArgMappingException, GuardInstallException) as _:
# trace normally if we can't map args or install guards correctly
pass
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx: "InstructionTranslator", name):
# Note: this allows us to intercept the call in call_method;
# in the typical case, we return a UserMethodVariable,
# which will inline directly
if name in ("_init_group", "step"):
return GetAttrVariable(self, name, source=AttrSource(self.source, name))
if name == "param_groups":
from ..decorators import mark_static_address
for group in self.value.param_groups:
for p in group["params"]:
mark_static_address(p)
self._set_capturable(tx)
return super().var_getattr(tx, name)
def graph_break_if_pending_mutation(self, tx):
# If there are pending mutations on a parameter (due to using closure)
# then we need to graph break to allow the python version of the parameter
# to update, so that running _init_group will initialize the states with
# the correct values
for g in self.value.param_groups:
for p in g["params"]:
side_effects = tx.output.side_effects
variable = side_effects.id_to_variable.get(id(p), None)
if variable and side_effects.has_pending_mutation(variable):
from ..exc import Unsupported
raise Unsupported("Pending mutation on parameter")
def _set_capturable(self, tx):
from . import LazyVariableTracker
# We only set capturable if the params are on CUDA/XPU
# and their state is not yet initialized
def safe_to_set_capturable(group):
all_uninitialized = True
all_gpu = True
for p in group.get("params", []):
all_gpu &= p.is_cuda or p.is_xpu
all_uninitialized &= p not in self.value.state
return "capturable" in group and all_uninitialized and all_gpu
# Track the indices we do not set so that we don't need to
# realize the whole state in the variable tracker;
# guarding of the state is handled specially
for group in self.value.param_groups:
if safe_to_set_capturable(group):
group["capturable"] = True
source = self.source and AttrSource(self.source, "param_groups")
param_groups_vt = LazyVariableTracker.realize_all(
VariableTracker.build(tx, self.value.param_groups, source)
)
for param_group_vt in param_groups_vt.items:
key = ConstDictVariable._HashableTracker(
ConstantVariable.create("capturable")
)
param_group_vt.items[key] = ConstantVariable.create(True)
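# Note: flipping "capturable" on both the real param group and its realized
# VariableTracker keeps the traced dict consistent with the eager one;
# capturable=True selects the cuda-graph-friendly path in optimizers such as
# Adam/AdamW, where the `step` state stays on the device.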
def get_python_args(self, *args, **kwargs):
"""Get python values equivalent to the variable tracker args"""
def map_arg(arg):
if isinstance(arg, ConstantVariable):
return arg.as_python_constant()
elif isinstance(arg, ListVariable) and not arg.items:
return []
elif (
isinstance(arg, ConstDictVariable)
and isinstance(arg.source, GetItemSource)
and isinstance(arg.source.base, AttrSource)
and arg.source.base.member == "param_groups"
):
return self.value.param_groups[arg.source.index]
raise ArgMappingException
new_args = [map_arg(arg) for arg in args]
new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}
return new_args, new_kwargs
# If users load an old state dictionary,
# it's possible that `step` could be on the CPU.
# If this is the case, move it to the device of the corresponding parameter.
# In most cases this is a no-op because the state is empty.
def move_step_if_cpu(self):
for p, state in self.value.state.items():
if "step" in state and state["step"].is_cpu:
state["step"] = state["step"].to(p.device)
def map_sources_and_install_guards(self, tx):
from ..decorators import mark_static_address
from .lazy import LazyVariableTracker
self.grad_to_source = {}
self.tensor_to_source = {}
# Tracing _init_group is expensive, but we still have to insert the
# necessary guards for it, so we handle guard insertion manually. We also
# want to mark all the tensors inside the state dict as having static
# addresses.
# Mark all the tensors in the state dict as static addresses. This has to
# be done first because the variable builder relies on the static-address
# annotation.
def mark_static(x):
mark_static_address(x)
tree_map_only(torch.Tensor, mark_static, self.value.state)
# Recursively realize the variable trackers for optim.state and
# optim.param_groups, which recursively install the necessary guards.
params_groups_source = self.source and AttrSource(self.source, "param_groups")
param_groups_vt = LazyVariableTracker.realize_all(
VariableTracker.build(tx, self.value.param_groups, params_groups_source)
)
state_source = self.source and AttrSource(self.source, "state")
state_vt = VariableTracker.build(tx, self.value.state, state_source)
# We need to realize the top level state dict to populate
# the guard locals
state_vt.realize()
tx.output.guard_on_key_order.add(state_source.name())
# Populate self.grad_to_source and self.tensor_to_source so that we can
# manually update_list_args
for group, group_vt in zip(self.value.param_groups, param_groups_vt.items):
# we assume here that all params within a param group
# are initialized similarly
if len(group["params"]) > 0:
for param in group["params"]:
if param.grad is not None:
key_index = None
for i, k in enumerate(self.value.state.keys()):
if k is param:
key_index = i
break
if key_index is not None:
LazyVariableTracker.realize_all(
VariableTracker.build(
tx,
self.value.state[param],
DictGetItemSource(
state_source,
ConstDictKeySource(state_source, key_index),
),
)
)
break
params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
all_static = True
non_static_grads = []
for p_ind, (p, p_vt) in enumerate(
zip(group["params"], params_vt.unpack_var_sequence(tx))
):
param_source = p_vt.source
self.tensor_to_source[p] = param_source
grad_source = GradSource(
param_source,
"grad",
)
if p.grad is not None:
self.grad_to_source[p.grad] = grad_source
if not _is_static_for_cudagraphs(p.grad):
all_static = False
non_static_grads.append(grad_source)
else:
install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))
# Note: to avoid spam logs only warn if perf hint artifact is enabled
# (NB: artifacts are only enabled at the debug or warning level)
if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG):
non_static_grads = [src.name() for src in non_static_grads]
perf_hint_log.warning(
"Grad tensors %s will be copied during cudagraphs execution. "
"If using cudagraphs and the grad tensor addresses will be the same across runs, "
"use torch._dynamo.decorators.mark_static_address to elide this copy.",
non_static_grads,
)
# We have to iterate over the state dict again to collect the
# tensor_to_source dict, which is later used by wrap_tensor (and, via
# static_tensor_names, by the finalizer).
for idx, (p, value) in enumerate(self.value.state.items()):
p_state_source = DictGetItemSource(
state_source, ConstDictKeySource(state_source, idx)
)
tx.output.guard_on_key_order.add(p_state_source.name())
for inner_idx, (k, v) in enumerate(value.items()):
if (
isinstance(v, torch.Tensor)
and v not in self.grad_to_source
and v not in self.tensor_to_source
):
self.tensor_to_source[v] = DictGetItemSource(
p_state_source, ConstDictKeySource(p_state_source, inner_idx)
)
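# The grad_to_source / tensor_to_source maps built above are consumed by
# wrap_tensor below: tensors found in them reuse their original source (and
# hence their guards), while any state tensor we missed falls back to a global
# weakref source and is marked static for cudagraphs.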
def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
"""Wrap state tensor in a TensorVariable"""
from ..decorators import mark_static_address
# If we already have a source for a tensor, use it.
# If we have not seen the tensor before, stash and use a
# global weak ref source, since it must be an optimizer tensor
# that we have missed
if tensor_value in self.tensor_to_source:
# mark these tensors as static for cudagraphs
mark_static_address(tensor_value)
source = self.tensor_to_source[tensor_value]
self.static_tensor_names.add(tx.output.module_key_name(source.name()))
elif tensor_value in self.grad_to_source:
source = self.grad_to_source[tensor_value]
else:
# mark these tensors as static for cudagraphs
mark_static_address(tensor_value)
global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
source = GlobalWeakRefSource(global_name)
self.static_tensor_names.add(tx.output.module_key_name(source.name()))
return VariableTracker.build(tx, tensor_value, source)
def update_list_args(
self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
):
"""Update the args and kwargs to the traced optimizer call"""
for arg, py_arg in zip(args, py_args):
if isinstance(arg, ListVariable):
assert isinstance(py_arg, list), (
"py_arg should be a list in optimizer variable"
)
for i, val in enumerate(py_arg):
tx.output.side_effects.mutation(arg)
if isinstance(val, torch.Tensor):
arg.items.append(self.wrap_tensor(tx, val))
else:
source = arg.source and GetItemSource(arg.source, i)
arg.items.append(VariableTracker.build(tx, val, source))
def create_finalizer(self, tx):
names_to_delete = self.static_tensor_names
value = self.value
tc = tx.output.tracing_context
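# Editorial note: the finalizer below is keyed on the optimizer itself. Once
# the optimizer is garbage collected, the static tensors that were installed on
# the GraphModule are popped from its buffers/parameters and the flattened
# parameter lists on the tracing context are cleared, so the compiled artifact
# does not keep the optimizer state alive.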
def init_finalizer(gm):
def clear_static_tensor_refs():
for name in names_to_delete:
gm._buffers.pop(name, None)
gm._parameters.pop(name, None)
if tc.params_flat:
tc.params_flat.clear()
if tc.params_flat_unwrap_subclasses:
tc.params_flat_unwrap_subclasses.clear()
weakref.finalize(value, clear_static_tensor_refs)
tx.output.add_graph_finalizer(init_finalizer)