Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Earlier, with the inline flag, we were lifting id-guarded tensors to be inputs to the FX graph. But this offers no benefit. The main idea behind lifting parameters as inputs was to reuse the compilation units across many instances of the nn-module. However, if we are guarding on the `id`, we are explicitly specializing the compiled artifact to the parameter. This PR installs the parameters back into the graph. The benefit is the removal of all pre-graph bytecode that extracts the id-guarded tensors from locals/globals. This increases the speedup from 1.67x to 1.75x for an internal model that has a large number of optimizer parameters. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147824 Approved by: https://github.com/jansel Co-authored-by: Jason Ansel <jansel@meta.com>
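
A rough sketch of the idea (hypothetical tensor and graph names, not taken from the PR):

# Before: an id-guarded parameter was lifted to a graph input, so Dynamo emitted
# pre-graph bytecode to fetch it from locals/globals and pass it in:
#
#     def forward(self, L_param_: torch.Tensor, L_x_: torch.Tensor):
#         return L_x_ + L_param_
#
# After: because the guard is on id(param), the compiled artifact is already
# specialized to that exact tensor, so it can be installed on the graph module
# itself and read via a get_attr node, with no pre-graph bytecode:
#
#     def forward(self, L_x_: torch.Tensor):
#         param = self.param
#         return L_x_ + param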
409 lines
16 KiB
Python
# mypy: ignore-errors

"""
This module implements variable tracking for PyTorch optimizers during Dynamo tracing.

The OptimizerVariable class provides specialized handling for optimizer instances by:
- Optimizing the tracing of expensive optimizer initialization
- Managing optimizer state and parameter group tracking
- Handling tensor sources and guards for optimizer state tensors
- Supporting CUDA graph execution through static tensor address management
- Providing special handling for parameter gradients and optimizer state tensors

Key features include:
- Efficient initialization tracing via _init_group optimization
- Automatic marking of optimizer state tensors as static for CUDA graphs
- Proper source tracking for parameter groups, gradients, and state tensors
- Guard installation for optimizer state structure
- Support for both CPU and GPU tensor handling
- Cleanup of static tensor references via finalizers

The module integrates with Dynamo's broader tracing system while providing
optimizer-specific optimizations and safety guarantees.
"""
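
# Illustrative usage sketch (simplified; `model`/`opt` are invented names): this
# module comes into play when an optimizer step is compiled, e.g.
#
#     model = torch.nn.Linear(8, 8)
#     opt = torch.optim.Adam(model.parameters())
#     model(torch.randn(2, 8)).sum().backward()
#
#     @torch.compile
#     def step():
#         opt.step()
#
#     step()
#
# Tracing `opt.step()` reaches `Optimizer._init_group`, which
# `OptimizerVariable.call_method` below intercepts rather than tracing its slow
# initialization bytecode by bytecode.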

import logging
import weakref
from typing import TYPE_CHECKING

import torch
from torch._logging import getArtifactLogger
from torch.utils._pytree import tree_map_only

from ..guards import GuardBuilder, install_guard
from ..source import (
    AttrSource,
    ConstDictKeySource,
    DictGetItemSource,
    GetItemSource,
    GlobalWeakRefSource,
    GradSource,
)
from ..utils import GLOBAL_KEY_PREFIX
from .base import VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


class ArgMappingException(Exception):
    pass


class GuardInstallException(Exception):
    pass


perf_hint_log = getArtifactLogger(__name__, "perf_hints")


def _is_static_for_cudagraphs(x):
    from torch._inductor.cudagraph_trees import get_manager

    if x.is_cuda:
        manager = get_manager(x.device.index, False)
        is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None
        if manager:
            return (
                is_static_address
                or manager.current_node._is_cuda_graph_recorded_tensor(x)
            )
        else:
            return is_static_address
    else:
        # Don't print a warning for non-cuda tensors
        return True


class OptimizerVariable(UserDefinedObjectVariable):
    _nonvar_fields = {
        "grad_to_source",
        "tensor_to_source",
        "static_tensor_names",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        grad_to_source=None,
        static_tensor_names=None,
        tensor_to_source=None,
        **kwargs,
    ) -> None:
        super().__init__(value, **kwargs)
        self.grad_to_source = grad_to_source or {}
        self.tensor_to_source = tensor_to_source or {}
        self.static_tensor_names = static_tensor_names or set()

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
        if name == "_init_group":
            try:
                self.graph_break_if_pending_mutation(tx)
                self.move_step_if_cpu()
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                ret_val = self.value._init_group(*py_args, **py_kwargs)
                self.map_sources_and_install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
                # stash a weak_ptr to optimizer to invalidate code
                # if the optimizer object dies
                mangled_name = f"__optimizer_{id(self.value)}"
                tx.store_global_weakref_by_id(mangled_name, self.value)
                self.create_finalizer(tx)

                # This is currently safe only because the only actual `ret_val`s returned
                # by the `_init_group` of existing optimizers are properties that are invariant
                # to the input tensors (e.g. dtype, layout). Changing these would trigger a
                # recompilation and hence never result in the wrong specialization of `ret_val`.
                return ConstantVariable.create(ret_val)
            except (ArgMappingException, GuardInstallException) as _:
                # trace normally if we can't map args or install guards correctly
                pass

        return super().call_method(tx, name, args, kwargs)
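
    # NOTE [_init_group fast path]
    # Instead of tracing `_init_group` bytecode by bytecode, `call_method` above:
    #   1. maps the variable-tracker args to real Python values (get_python_args),
    #   2. runs the real `_init_group` eagerly on the underlying optimizer,
    #   3. realizes variable trackers for param_groups/state so the necessary
    #      guards are installed (map_sources_and_install_guards), and
    #   4. patches the traced list arguments with the tensors the eager call
    #      appended (update_list_args),
    # so the rest of `step()` traces as if `_init_group` had been traced normally.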

    def var_getattr(self, tx: "InstructionTranslator", name):
        # Note: this allows us to intercept the call in call_method;
        # in the typical case, we return a UserMethodVariable,
        # which will directly inline
        if name in ("_init_group", "step"):
            return GetAttrVariable(self, name, source=AttrSource(self.source, name))

        if name == "param_groups":
            from ..decorators import mark_static_address

            for group in self.value.param_groups:
                for p in group["params"]:
                    mark_static_address(p)

            self._set_capturable(tx)

        return super().var_getattr(tx, name)

    def graph_break_if_pending_mutation(self, tx):
        # If there are pending mutations on a parameter (due to using closure)
        # then we need to graph break to allow the python version of the
        # parameter to update, so that running _init_group will initialize the
        # states with the correct values
        for g in self.value.param_groups:
            for p in g["params"]:
                side_effects = tx.output.side_effects
                variable = side_effects.id_to_variable.get(id(p), None)
                if variable and side_effects.has_pending_mutation(variable):
                    from ..exc import Unsupported

                    raise Unsupported("Pending mutation on parameter")
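
    # Hedged illustration (invented names) of the pending-mutation case above:
    #
    #     @torch.compile
    #     def fn(p):
    #         p.grad = torch.ones_like(p)  # mutation is still pending in Dynamo
    #         opt.step()                   # _init_group must see the new grad
    #
    # Because `_init_group` is executed eagerly (see call_method), it would read
    # the stale `.grad` here, so we graph break and let Python apply the
    # mutation first.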

    def _set_capturable(self, tx):
        from . import LazyVariableTracker

        # We only set capturable if params are on cuda
        # and the state is not initialized
        def safe_to_set_capturable(group):
            all_uninitialized = True
            all_gpu = True

            for p in group.get("params", []):
                all_gpu &= p.is_cuda or p.is_xpu
                all_uninitialized &= p not in self.value.state

            return "capturable" in group and all_uninitialized and all_gpu

        # Track which indices not to set so that we don't need to realize the
        # whole state in the variable tracker; we handle guarding the state
        # specially.
        for group in self.value.param_groups:
            if safe_to_set_capturable(group):
                group["capturable"] = True

        source = self.source and AttrSource(self.source, "param_groups")
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableTracker.build(tx, self.value.param_groups, source)
        )
        for param_group_vt in param_groups_vt.items:
            key = ConstDictVariable._HashableTracker(
                ConstantVariable.create("capturable")
            )
            param_group_vt.items[key] = ConstantVariable.create(True)
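
    # Background on `capturable` (a hedged note, not from the original file):
    # for optimizers such as Adam, `capturable=True` keeps the step counter and
    # state updates on the device, which is what makes `step()` safe to capture
    # in a CUDA graph. `_set_capturable` flips the flag on the eager optimizer
    # and mirrors the change into the traced ConstDictVariable so both stay
    # consistent.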

    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args"""

        def map_arg(arg):
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                return self.value.param_groups[arg.source.index]

            raise ArgMappingException

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

        return new_args, new_kwargs

    # If users load an old state dictionary, it's possible that step could be
    # on the CPU. If this is the case, move it to the GPU corresponding to the
    # parameter. In most cases this is a no-op because the state is empty.
    def move_step_if_cpu(self):
        for p, state in self.value.state.items():
            if "step" in state and state["step"].is_cpu:
                state["step"] = state["step"].to(p.device)
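
    # Hedged illustration of when move_step_if_cpu matters: after
    # `opt.load_state_dict(torch.load(path, map_location="cpu"))`, each
    # state["step"] tensor lives on the CPU even when its parameter is on CUDA;
    # the loop above moves it to the parameter's device before `_init_group`
    # runs eagerly.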

    def map_sources_and_install_guards(self, tx):
        from ..decorators import mark_static_address
        from .lazy import LazyVariableTracker

        self.grad_to_source = {}
        self.tensor_to_source = {}

        # Tracing the _init_group is expensive. But we still have to insert the
        # necessary guards for _init_group. So, we manually handle insertion of
        # guards. We also want to mark all the tensors inside the state dict to
        # be static address.

        # Mark all the tensors in the state dict to be static address. This has
        # to be done first because the variable builder relies on the static
        # address annotation.
        def mark_static(x):
            mark_static_address(x)

        tree_map_only(torch.Tensor, mark_static, self.value.state)

        # Recursively realize the variable trackers for optim.state and
        # optim.param_groups, which recursively install the necessary guards.
        params_groups_source = self.source and AttrSource(self.source, "param_groups")
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableTracker.build(tx, self.value.param_groups, params_groups_source)
        )

        state_source = self.source and AttrSource(self.source, "state")

        state_vt = VariableTracker.build(tx, self.value.state, state_source)

        # We need to realize the top level state dict to populate
        # the guard locals
        state_vt.realize()
        tx.output.guard_on_key_order.add(state_source.name())

        # Populate self.grad_to_source and self.tensor_to_source so that we can
        # manually update_list_args
        for group, group_vt in zip(self.value.param_groups, param_groups_vt.items):
            # we assume here that all params within a param group
            # are initialized similarly
            if len(group["params"]) > 0:
                for param in group["params"]:
                    if param.grad is not None:
                        key_index = None
                        for i, k in enumerate(self.value.state.keys()):
                            if k is param:
                                key_index = i
                                break
                        # Note: compare against None so that index 0 (the first
                        # parameter) is not skipped as falsy.
                        if key_index is not None:
                            LazyVariableTracker.realize_all(
                                VariableTracker.build(
                                    tx,
                                    self.value.state[param],
                                    DictGetItemSource(
                                        state_source,
                                        ConstDictKeySource(state_source, key_index),
                                    ),
                                )
                            )
                            break

            params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
            all_static = True
            non_static_grads = []
            for p_ind, (p, p_vt) in enumerate(
                zip(group["params"], params_vt.unpack_var_sequence(tx))
            ):
                param_source = p_vt.source
                self.tensor_to_source[p] = param_source
                grad_source = GradSource(
                    param_source,
                    "grad",
                )

                if p.grad is not None:
                    self.grad_to_source[p.grad] = grad_source
                    if not _is_static_for_cudagraphs(p.grad):
                        all_static = False
                        non_static_grads.append(grad_source)
                else:
                    install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))

            # Note: to avoid spam logs only warn if perf hint artifact is enabled
            # (NB: artifacts are only enabled at the debug or warning level)
            if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG):
                non_static_grads = [src.name() for src in non_static_grads]
                perf_hint_log.warning(
                    (
                        "Grad tensors %s will be copied during cudagraphs execution. "
                        "If using cudagraphs and the grad tensor addresses will be the same across runs,"
                        " use torch._dynamo.decorators.mark_static_address to elide this copy."
                    ),
                    non_static_grads,
                )

        # We have to again iterate over the state dict to collect the
        # tensor_to_source dict. This is used for the finalizer.
        for idx, (p, value) in enumerate(self.value.state.items()):
            p_state_source = DictGetItemSource(
                state_source, ConstDictKeySource(state_source, idx)
            )
            tx.output.guard_on_key_order.add(p_state_source.name())
            for inner_idx, (k, v) in enumerate(value.items()):
                if (
                    isinstance(v, torch.Tensor)
                    and v not in self.grad_to_source
                    and v not in self.tensor_to_source
                ):
                    self.tensor_to_source[v] = DictGetItemSource(
                        p_state_source, ConstDictKeySource(p_state_source, inner_idx)
                    )

    def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
        """Wrap state tensor in a TensorVariable"""
        from ..decorators import mark_static_address

        # If we already have a source for a tensor, use it; if we have not seen
        # the tensor before, stash and use a global weak ref source, since it
        # must be an optimizer tensor that we have missed.
        if tensor_value in self.tensor_to_source:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)
            source = self.tensor_to_source[tensor_value]
            self.static_tensor_names.add(tx.output.module_key_name(source.name()))
        elif tensor_value in self.grad_to_source:
            source = self.grad_to_source[tensor_value]
        else:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)

            global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
            source = GlobalWeakRefSource(global_name)
            self.static_tensor_names.add(tx.output.module_key_name(source.name()))

        return VariableTracker.build(tx, tensor_value, source)

    def update_list_args(
        self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
    ):
        """Update the args and kwargs to the traced optimizer call"""
        for arg, py_arg in zip(args, py_args):
            if isinstance(arg, ListVariable):
                assert isinstance(py_arg, list), (
                    "py_arg should be a list in optimizer variable"
                )
                for i, val in enumerate(py_arg):
                    tx.output.side_effects.mutation(arg)
                    if isinstance(val, torch.Tensor):
                        arg.items.append(self.wrap_tensor(tx, val))
                    else:
                        source = arg.source and GetItemSource(arg.source, i)
                        arg.items.append(VariableTracker.build(tx, val, source))

    def create_finalizer(self, tx):
        names_to_delete = self.static_tensor_names
        value = self.value
        tc = tx.output.tracing_context

        def init_finalizer(gm):
            def clear_static_tensor_refs():
                for name in names_to_delete:
                    gm._buffers.pop(name, None)
                    gm._parameters.pop(name, None)
                    if tc.params_flat:
                        tc.params_flat.clear()
                    if tc.params_flat_unwrap_subclasses:
                        tc.params_flat_unwrap_subclasses.clear()

            weakref.finalize(value, clear_static_tensor_refs)

        tx.output.add_graph_finalizer(init_finalizer)