pytorch/torch/_library/autograd.py
rzou 25c65d6642 Change register_autograd to reflect ordering of setup_context and backward (#124403)
old: `register_autograd(setup_context, backward, /)`
new: `register_autograd(backward, /, *, setup_context=None)`

Motivations:
- We introduce these APIs as "give us a backward and use setup_context
  to save things for backward".
- setup_context isn't always necessary.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124403
Approved by: https://github.com/albanD
ghstack dependencies: #124180, #124200, #124299, #124134, #124199
2024-04-19 17:56:30 +00:00
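
For reference, a rough sketch of the new calling convention (the custom op below is hypothetical and only illustrates the argument order; it is not part of this change):

import torch

@torch.library.custom_op("mylib::my_sin", mutates_args=())
def my_sin(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x)

def backward(ctx, grad):
    (x,) = ctx.saved_tensors
    return grad * x.cos()

def setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(x)

# backward is now positional; setup_context is an optional keyword argument.
my_sin.register_autograd(backward, setup_context=setup_context)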


import dataclasses
from typing import Any, Callable, Optional, Protocol

from .. import _C, _ops, autograd, Tensor
from ..utils import _pytree
from . import utils


class InfoProtocol(Protocol):
    _backward_fn: Optional[Callable]
    _setup_context_fn: Optional[Callable]


@dataclasses.dataclass
class Info:
    _backward_fn: Optional[Callable]
    _setup_context_fn: Optional[Callable]
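

# Builds the Autograd kernel for `op`: creates an autograd.Function subclass
# whose forward redispatches below the Autograd key and whose backward/setup
# call the user-provided functions stored on `info`.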
def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
    name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}"

    saved_keyset = None

    def forward(ctx, *args):
        with _C._AutoDispatchBelowAutograd():
            nonlocal saved_keyset
            keyset = saved_keyset
            assert keyset is not None, "Should have been set by autograd_impl"
            saved_keyset = None
            result = op.redispatch(keyset & _C._after_autograd_keyset, *args)
            if info._setup_context_fn:
                info._setup_context_fn(ctx, args, result)
            return result

    def backward(ctx, *grads):
        if info._backward_fn:
            result = info._backward_fn(ctx, *grads)
            return result
        raise RuntimeError(
            f"Trying to backward through {op} but no autograd "
            f"formula was registered. "
            f"Please use register_autograd to add one."
        )
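
    # Dynamically build a named autograd.Function subclass so that each op gets
    # its own grad_fn class (the `name` above encodes namespace/op/overload).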
    Generated = type(
        name,
        (autograd.Function,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
        },
    )
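
    # Ops that take or return List[Tensor] need the tensorlist wrapper so that
    # autograd only ever sees flat Tensor arguments and returns.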
    schema = op._schema
    if any(
        utils.is_tensorlist_like_type(a.type)
        for a in (*schema.arguments, *schema.returns)
    ):
        Generated = supports_tensorlist(Generated)

    def autograd_impl(keyset, *args):
        # We set a nonlocal to ferry keyset from here to the forward.
        # This supports recursive calls (we implement the forward carefully so
        # that it'll read saved_keyset before making a recursive call to the op).
        nonlocal saved_keyset
        assert saved_keyset is None
        saved_keyset = keyset
        result = Generated.apply(*args)  # type: ignore[attr-defined]
        return result

    return autograd_impl


def supports_tensorlist(cls: Any) -> Any:
    """Allows a given autograd.Function class to support List[Tensor] inputs/outputs.

    Regular autograd.Function has a constraint that it only directly supports autograd for
    Tensors. Applying @supports_tensorlist enables an autograd.Function to support
    autograd for List[Tensor] inputs and outputs.
    """
    # NB: All calls to the autograd.Function.apply share these variables.
    # We assume that only one call to .apply happens at a time. This means that
    # you cannot call the autograd.Function recursively (e.g. from its own forward).
    input_spec: Optional[spec_t] = None
    output_spec: Optional[spec_t] = None
    result_is_tuple = None

    orig_forward = cls.forward
    orig_backward = cls.backward
    orig_apply = cls.apply

    def new_forward(ctx, *args):
        if input_spec is None:
            raise NotImplementedError(
                "NYI: calling supports_tensorlist autograd.Function.forward directly. "
                "You should probably be calling .apply instead. "
                "Please file an issue if not."
            )
        args = unflatten(list(args), input_spec)
        result = orig_forward(ctx, *args)
        nonlocal output_spec
        nonlocal result_is_tuple
        result_is_tuple = isinstance(result, tuple)
        if not result_is_tuple:
            result = (result,)
        flat_result, output_spec = flatten(result, not_list_of_tensor)

        # Save the input_spec/output_spec for backward because another call to
        # .apply will override the nonlocals.
        if hasattr(ctx, "_pt_metadata"):
            raise RuntimeError(
                "Please don't set ctx._pt_metadata; PyTorch uses it to store info"
            )
        ctx._pt_metadata = (input_spec, output_spec)
        return tuple(flat_result)

    def new_backward(ctx, *grads):
        if not hasattr(ctx, "_pt_metadata"):
            raise NotImplementedError(
                "NYI: calling supports_tensorlist autograd.Function.backward directly. "
                "This will automatically get called by PyTorch autograd. "
                "Please file an issue if you need this."
            )
        input_spec, output_spec = ctx._pt_metadata
        grads = unflatten(list(grads), output_spec)
        grad_inputs = orig_backward(ctx, *grads)
        if not isinstance(grad_inputs, tuple):
            grad_inputs = (grad_inputs,)
        # Assume that any Nones in the backward are Tensors.
        # If the forward has an arg that is [1, 2, 3], the backward should
        # return None as the grad.
        # If the forward has an arg that is [tensor, tensor], the backward
        # may return [None, None], [grad, None], [None, grad], or [grad, grad].
        flat_grad_inputs, grad_inputs_spec = flatten(
            grad_inputs, not_list_of_optional_tensor
        )
        if grad_inputs_spec != input_spec:
            raise RuntimeError(
                f"Expected the return from backward to be of the same structure "
                f"as the inputs. Got: {grad_inputs_spec} (return from backward), "
                f"{input_spec} (inputs)"
            )
        return tuple(flat_grad_inputs)

    def new_apply(*args):
        nonlocal input_spec
        if input_spec is not None:
            raise NotImplementedError(
                "NYI: Recursive call to autograd.Function decorated with "
                "`supports_tensorlist`. Please file an issue."
            )
        try:
            flat_args, input_spec = flatten(args, is_leaf=not_list_of_tensor)
            result = orig_apply(*flat_args)  # type: ignore[misc]
        finally:
            input_spec = None
        assert output_spec is not None
        result = unflatten(list(result), output_spec)
        if not result_is_tuple:
            assert isinstance(result, tuple)
            assert len(result) == 1
            return result[0]
        return result

    cls.forward = new_forward
    cls.backward = new_backward
    cls.apply = new_apply
    return cls
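

# is_leaf predicates for the pytree flatten calls above: tuples and lists made
# up entirely of Tensors (or, for grads, of Tensors/Nones) are recursed into;
# everything else is treated as a leaf.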
def not_list_of_tensor(tree):
    if isinstance(tree, tuple):
        return False
    if isinstance(tree, list):
        return any(not isinstance(l, Tensor) for l in tree)
    return True

def not_list_of_optional_tensor(tree):
    if isinstance(tree, tuple):
        return False
    if isinstance(tree, list):
        return any(l is not None and not isinstance(l, Tensor) for l in tree)
    return True
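
# Short aliases for the pytree utilities used above.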
flatten = _pytree.tree_flatten
unflatten = _pytree.tree_unflatten
spec_t = _pytree.TreeSpec