We will always proxy autograd.Function nodes in compiled autograd's initial graph capture (previously there was an option to proxy vs trace into the autograd.Function).

We have some requirements for the AOTBackward. Compiled Autograd runs accumulate grad reordering passes on the AOTBackward graph directly after the initial graph capture, so we can't just proxy a single node for it. Instead, we:
- proxy the AOTBackward prologue function into the CA graph
- copy-paste the AOTBackward graph into the CA graph
- trace directly through the epilogue (the traced nodes go into the CA graph).

Tracing through the epilogue is safe (assuming no Tensor subclasses) because the only thing the epilogue does is drop some outputs. The Tensor subclass situation was already broken, so this doesn't regress anything, but this PR sets it up to be fixed (in a followup, where we will proxy "make_subclass" calls into the graph from the epilogue).

Test Plan:
- existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143405
Approved by: https://github.com/jansel, https://github.com/xmfan
ghstack dependencies: #143296, #143304, #143387
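For context, a minimal sketch of the kind of program this path serves. The names (MySquare, train_step) are made up for illustration, and torch._dynamo.config.compiled_autograd is just one way to turn compiled autograd on; the exact enabling API may differ across versions:

    import torch

    class MySquare(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        def backward(ctx, grad):
            (x,) = ctx.saved_tensors
            return grad * 2 * x

    torch._dynamo.config.compiled_autograd = True

    @torch.compile
    def train_step(x):
        loss = MySquare.apply(x).sum()
        loss.backward()

    x = torch.randn(4, requires_grad=True)
    train_step(x)

With this change, compiled autograd's initial capture always represents MySquare's backward as a proxied node (roughly, a call into external_utils.call_backward below) instead of optionally tracing into it.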
168 lines
4.7 KiB
Python
# This module contains functions that *will be allowed* by dynamo

import functools
import warnings
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union

from typing_extensions import deprecated, ParamSpec

import torch
import torch.utils._pytree as pytree


try:
    import numpy as np
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]


_P = ParamSpec("_P")
_R = TypeVar("_R")


if TYPE_CHECKING:
    # TorchScript does not support `@deprecated`
    # This is a workaround to avoid breaking TorchScript
    @deprecated(
        "`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
        category=FutureWarning,
    )
    def is_compiling() -> bool:
        return torch.compiler.is_compiling()

else:

    def is_compiling() -> bool:
        """
        Indicates whether we are tracing/compiling with torch.compile() or torch.export().
        """
        # NOTE: With `@torch.compile(backend="eager")`, torch._dynamo.is_compiling() will get traced
        # and return true. torch.compiler.is_compiling() is skipped and will return false.
        return torch.compiler.is_compiling()
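
# Illustrative use (hypothetical helper, not part of the original module): code
# that should only run in eager mode can branch on this flag.
#
#     def maybe_log(x: torch.Tensor) -> torch.Tensor:
#         if not is_compiling():
#             print("eager call with shape", x.shape)  # skipped while tracing
#         return x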


def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    Create an extra frame around fn that is not in skipfiles.
    """

    @functools.wraps(fn)
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        return fn(*args, **kwargs)

    return inner
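
# Illustrative use (hypothetical names, not from the original file): wrapping a
# function defined in a skipped file gives dynamo an allowed frame to inline
# through.
#
#     wrapped = wrap_inline(helper_from_skipped_file)
#     y = wrapped(x)  # dynamo traces `inner`, which calls the original helper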


def call_hook(
    hook: Callable[..., Optional[torch.Tensor]], *args: Any, **kwargs: Any
) -> torch.Tensor:
    """
    Used by compiled autograd to handle hook returning None.
    """
    result = hook(*args)
    if result is None:
        return args[0]
    elif kwargs.get("hook_type") == "post_acc_grad_hook":
        raise RuntimeError("Tensor post accumulate grad hooks should return None.")
    return result
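
# Behavior sketch (illustrative values, not part of the original file):
#
#     call_hook(lambda g: None, grad)   # hook returned None -> grad passes through
#     call_hook(lambda g: g * 2, grad)  # hook returned a tensor -> use it
#     call_hook(lambda g: g, grad, hook_type="post_acc_grad_hook")
#     # -> RuntimeError: post accumulate grad hooks must return None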


def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]:
    r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
    from ``torch.Tensor``s to ``torch.Tensor``s.
    """
    if not np:
        return f

    @functools.wraps(f)
    def wrap(*args: _P.args, **kwargs: _P.kwargs) -> pytree.PyTree:
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, lambda x: x.numpy(), (args, kwargs)
        )
        out = f(*args, **kwargs)
        return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)

    return wrap
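
# Example (hypothetical function, for illustration only):
#
#     @wrap_numpy
#     def np_add(a, b):       # written against np.ndarray inputs/outputs
#         return np.add(a, b)
#
#     out = np_add(torch.ones(3), torch.ones(3))  # Tensors in, Tensor out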


class FakeBackwardCFunction:
    def __init__(
        self,
        real: torch.autograd.function.BackwardCFunction,
        saved_tensors: list[torch.Tensor],
    ) -> None:
        self.real = real
        self.saved_tensors = saved_tensors

    def __getattr__(self, name: str) -> Any:
        if name == "saved_variables":
            warnings.warn(
                "'saved_variables' is deprecated; use 'saved_tensors'",
                DeprecationWarning,
            )
            return self.saved_tensors

        return getattr(self.real, name)


def call_backward(
    backward_c_function: torch.autograd.function.BackwardCFunction,
    saved_tensors: list[torch.Tensor],
    *args: Any,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
    fake = FakeBackwardCFunction(backward_c_function, saved_tensors)
    grads = fake._forward_cls.backward(fake, *args)  # type: ignore[attr-defined]

    if not isinstance(grads, tuple):
        grads = (grads,)

    return grads
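
# Rough picture of what happens at runtime (names here are illustrative, not
# from the original file): given a user-defined autograd.Function,
#
#     class Double(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, x):
#             return x * 2
#
#         @staticmethod
#         def backward(ctx, grad):
#             return grad * 2
#
# the compiled autograd graph calls call_backward(ctx, saved_tensors, grad),
# which wraps the real ctx in a FakeBackwardCFunction (so ctx.saved_tensors
# resolves to the tensors passed in) and dispatches to Double.backward,
# always returning a tuple of grads.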


def normalize_as_list(x: Any) -> list[Any]:
    if isinstance(x, tuple):
        return list(x)
    elif isinstance(x, list):
        return x
    return [x]
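
# e.g. normalize_as_list((a, b)) -> [a, b]; normalize_as_list(a) -> [a]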


def untyped_storage_size(x: torch.Tensor) -> int:
    return x.untyped_storage().size()


class FakeCompiledAutogradEngine:
    @staticmethod
    def queue_callback(
        final_callbacks: list[Callable[[], None]], cb: Callable[[], None]
    ) -> None:
        final_callbacks.append(cb)

    @staticmethod
    def exec_final_callbacks(final_callbacks: list[Callable[[], None]]) -> None:
        i = 0
        while i < len(final_callbacks):
            cb = final_callbacks[i]
            cb()
            i += 1
        final_callbacks.clear()

    @staticmethod
    def _exec_final_callbacks_stub() -> None:
        pass
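
# Note / illustrative use (not part of the original file): exec_final_callbacks
# iterates by index and rechecks len() each step, so callbacks that queue
# further callbacks are also executed before the list is cleared.
#
#     final_callbacks: list[Callable[[], None]] = []
#     FakeCompiledAutogradEngine.queue_callback(final_callbacks, lambda: print("done"))
#     FakeCompiledAutogradEngine.exec_final_callbacks(final_callbacks)  # prints "done"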


def call_hook_from_backward_state(
    *args: Any, bw_state: Any, hook_name: str, **kwargs: Any
) -> Any:
    return getattr(bw_state, hook_name)(*args, **kwargs)


def call_module_hooks_from_backward_state(
    _: Any, result: Any, *args: Any, bw_state: Any, hooks_name: str, module_name: str
) -> Any:
    module = getattr(bw_state, module_name)
    hooks = getattr(bw_state, hooks_name)
    for hook in hooks:
        new_result = hook(module, result, *args)
        if new_result is not None:
            result = new_result
    return result
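
# Sketch of the calling convention (attribute names are illustrative, not from
# the original file): the compiled graph carries a backward-state object and
# looks hooks up on it by name at runtime, so the hooks themselves never need
# to be traced.
#
#     call_hook_from_backward_state(grad, bw_state=bw_state, hook_name="hook0")
#     # i.e. bw_state.hook0(grad)
#
#     call_module_hooks_from_backward_state(
#         None, grad_out, grad_in, bw_state=bw_state,
#         hooks_name="mod_hooks", module_name="mod",
#     )
#     # runs each hook as hook(bw_state.mod, grad_out, grad_in), threading any
#     # non-None return value through as the new result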