pytorch/torch/_dynamo/external_utils.py
rzou ec820fe57c [compiled autograd] Always proxy autograd.Function nodes; handle AOT backwards (#143405)
We will always proxy autograd.Function nodes in compiled autograd's
initial graph capture (previously there was an
option to proxy vs trace into the autograd.Function)

We have some requirements for the AOTBackward. Compiled Autograd runs
accumulate grad reordering passes on the AOTBackward graph directly
after the initial graph capture, so we can't just proxy a single node for it.

Instead, we:
- proxy the AOTBackward prologue function into the CA graph
- copy-paste the AOTBackward graph into the CA graph
- trace directly through the epilogue (the traced nodes go into the CA
  graph).

Tracing through the epilogue is safe (assuming no Tensor subclasses)
because the only thing the epilogue does is drop some outputs. The
Tensor subclass situation was already broken, so this doesn't regress
anything; instead, this PR sets it up to be fixed (in a followup, where
we will proxy "make_subclass" calls into the graph from the epilogue).

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143405
Approved by: https://github.com/jansel, https://github.com/xmfan
ghstack dependencies: #143296, #143304, #143387
2025-01-22 21:50:56 +00:00

168 lines
4.7 KiB
Python

# This module contains functions that *will be allowed* by dynamo
import functools
import warnings
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import deprecated, ParamSpec
import torch
import torch.utils._pytree as pytree
try:
import numpy as np
except ModuleNotFoundError:
np = None # type: ignore[assignment]
_P = ParamSpec("_P")
_R = TypeVar("_R")
if TYPE_CHECKING:
# TorchScript does not support `@deprecated`
# This is a workaround to avoid breaking TorchScript
@deprecated(
"`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
category=FutureWarning,
)
def is_compiling() -> bool:
return torch.compiler.is_compiling()
else:
def is_compiling() -> bool:
"""
Indicates whether we are tracing/compiling with torch.compile() or torch.export().
"""
# NOTE: With `@torch.compile(backend="eager")`, torch._dynamo.is_compiling() will get traced
# and return true. torch.compiler.is_compiling() is skipped and will return false.
return torch.compiler.is_compiling()
def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    Wrap fn in one additional Python frame that is not in skipfiles.

    The returned callable forwards every positional and keyword argument to
    fn unchanged and returns whatever fn returns.
    """

    @functools.wraps(fn)
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # Pure pass-through: the only purpose of this function is the frame itself.
        forwarded = fn(*args, **kwargs)
        return forwarded

    return inner
def call_hook(
    hook: Callable[..., Optional[torch.Tensor]], *args: Any, **kwargs: Any
) -> torch.Tensor:
    """
    Used by compiled autograd to handle hook returning None.

    The hook is invoked with the positional arguments only. When it returns
    None, the first positional argument is returned in its place. A non-None
    result is returned as is, unless ``hook_type="post_acc_grad_hook"`` was
    passed, in which case a RuntimeError is raised.
    """
    outcome = hook(*args)
    if outcome is not None:
        # kwargs are metadata for this helper; they are never forwarded to the hook.
        if kwargs.get("hook_type") == "post_acc_grad_hook":
            raise RuntimeError("Tensor post accumulate grad hooks should return None.")
        return outcome
    return args[0]
def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]:
    r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
    from ``torch.Tensor``s to ``torch.Tensor``s.

    If numpy could not be imported, ``f`` is returned unchanged.
    """
    if np is None:
        return f

    @functools.wraps(f)
    def wrap(*args: _P.args, **kwargs: _P.kwargs) -> pytree.PyTree:
        # Convert every tensor found anywhere in the inputs to an ndarray.
        np_args, np_kwargs = pytree.tree_map_only(
            torch.Tensor, lambda t: t.numpy(), (args, kwargs)
        )
        np_out = f(*np_args, **np_kwargs)
        # Convert every ndarray found anywhere in the output back to a tensor.
        return pytree.tree_map_only(
            np.ndarray, lambda arr: torch.as_tensor(arr), np_out
        )

    return wrap
class FakeBackwardCFunction:
    """
    Stand-in for a ``BackwardCFunction`` whose saved tensors are supplied explicitly.

    ``real`` and ``saved_tensors`` live on the wrapper itself; any other
    attribute lookup is forwarded to ``real``. Reading the deprecated
    ``saved_variables`` alias emits a DeprecationWarning and returns
    ``saved_tensors``.
    """

    def __init__(
        self,
        real: torch.autograd.function.BackwardCFunction,
        saved_tensors: list[torch.Tensor],
    ) -> None:
        self.real = real
        self.saved_tensors = saved_tensors

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs after normal lookup has failed. If one of the
        # wrapper's own attributes is missing (instance built without __init__,
        # e.g. by copy or unpickling), forwarding through ``self.real`` would
        # re-enter this method until RecursionError, so fail normally instead.
        if name in ("real", "saved_tensors"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        if name == "saved_variables":
            warnings.warn(
                "'saved_variables' is deprecated; use 'saved_tensors'",
                DeprecationWarning,
                # Attribute the warning to the code reading the attribute.
                stacklevel=2,
            )
            return self.saved_tensors
        return getattr(self.real, name)
def call_backward(
    backward_c_function: torch.autograd.function.BackwardCFunction,
    saved_tensors: list[torch.Tensor],
    *args: Any,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
    """
    Run the ``backward`` of the function's ``_forward_cls`` with the given saved tensors.

    The context passed to ``backward`` is a FakeBackwardCFunction wrapping
    ``backward_c_function`` and ``saved_tensors``. The result is always
    returned as a tuple; a non-tuple result is wrapped in a 1-tuple.
    """
    ctx = FakeBackwardCFunction(backward_c_function, saved_tensors)
    # _forward_cls is not defined on the wrapper, so it resolves through its __getattr__.
    outputs = ctx._forward_cls.backward(ctx, *args)  # type: ignore[attr-defined]
    return outputs if isinstance(outputs, tuple) else (outputs,)
def normalize_as_list(x: Any) -> list[Any]:
    """
    Coerce x to a list.

    A list is returned as is (same object), a tuple is copied into a new
    list, and any other value is wrapped in a single-element list.
    """
    if isinstance(x, list):
        return x
    if isinstance(x, tuple):
        return list(x)
    return [x]
def untyped_storage_size(x: torch.Tensor) -> int:
    """
    Return the size reported by the untyped storage backing x.
    """
    storage = x.untyped_storage()
    return storage.size()
class FakeCompiledAutogradEngine:
    """
    Static helpers that operate on a caller-owned list of final callbacks.
    """

    @staticmethod
    def queue_callback(
        final_callbacks: list[Callable[[], None]], cb: Callable[[], None]
    ) -> None:
        """Append cb to the end of final_callbacks."""
        final_callbacks.append(cb)

    @staticmethod
    def exec_final_callbacks(final_callbacks: list[Callable[[], None]]) -> None:
        """Call every queued callback in order, then empty the list."""
        # Re-read the length on every step: a callback may queue further
        # callbacks while running, and those must be executed as well.
        position = 0
        while position < len(final_callbacks):
            final_callbacks[position]()
            position += 1
        final_callbacks.clear()

    @staticmethod
    def _exec_final_callbacks_stub() -> None:
        """Do nothing."""
        return None
def call_hook_from_backward_state(
    *args: Any, bw_state: Any, hook_name: str, **kwargs: Any
) -> Any:
    """
    Look up the attribute named hook_name on bw_state and call it.

    All remaining positional and keyword arguments are forwarded to the
    hook, and its return value is returned unchanged.
    """
    hook = getattr(bw_state, hook_name)
    return hook(*args, **kwargs)
def call_module_hooks_from_backward_state(
    _: Any, result: Any, *args: Any, bw_state: Any, hooks_name: str, module_name: str
) -> Any:
    """
    Run a sequence of module hooks stored on bw_state, threading result through them.

    The module and the iterable of hooks are read from bw_state via
    module_name and hooks_name. Each hook is called as
    ``hook(module, result, *args)``; a non-None return value replaces result
    for the following hooks. The first positional argument is ignored.
    """
    owner = getattr(bw_state, module_name)
    for hook in getattr(bw_state, hooks_name):
        replacement = hook(owner, result, *args)
        if replacement is None:
            # The hook chose not to override the current result.
            continue
        result = replacement
    return result