Change register_autograd to reflect ordering of setup_context and backward (#124403)

old: `register_autograd(setup_context, backward, /)`
new: `register_autograd(backward, /, *, setup_context=None)`

Motivations:
- We present these APIs as "give us a backward, and use setup_context
  to save things for the backward pass".
- setup_context isn't always necessary, so it is now an optional,
  keyword-only argument (see the sketch below).
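For illustration, a minimal sketch of the new calling convention, reusing the `setup_context`/`backward` pair from the docstring example in this diff. It assumes the `mylib::numpy_sin` custom op from that example is already defined via `torch.library.custom_op`; `mylib::some_op` and `some_backward` at the end are hypothetical names used only to show the optional keyword.

```python
import torch

def setup_context(ctx, inputs, output):
    # Save what the backward pass needs.
    (x,) = inputs
    ctx.save_for_backward(x)

def backward(ctx, grad):
    (x,) = ctx.saved_tensors
    return grad * x.cos()

# Old ordering (positional, setup_context first):
#   torch.library.register_autograd("mylib::numpy_sin", setup_context, backward)
# New ordering (backward first; setup_context is keyword-only and optional):
torch.library.register_autograd("mylib::numpy_sin", backward, setup_context=setup_context)

# If a backward needs nothing saved from the forward, setup_context can simply
# be omitted (hypothetical op/backward names):
#   torch.library.register_autograd("mylib::some_op", some_backward)
```

Keeping `backward` as the only required positional argument matches how the API is described: the backward formula is mandatory, and saving state for it is opt-in.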

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124403
Approved by: https://github.com/albanD
ghstack dependencies: #124180, #124200, #124299, #124134, #124199
rzou authored 2024-04-19 06:31:56 -07:00; committed by PyTorch MergeBot
parent a8e17b2d4d
commit 25c65d6642
5 changed files with 37 additions and 24 deletions

View File

@@ -2486,14 +2486,18 @@ class TestCustomOpAPI(TestCase):
             return grad * x.cos()
 
         if mode == "function":
-            torch.library.register_autograd(numpy_sin, setup_context, backward)
+            torch.library.register_autograd(
+                numpy_sin, backward, setup_context=setup_context
+            )
         elif mode == "qualname":
             torch.library.register_autograd(
-                "mylib::numpy_sin", setup_context, backward
+                "mylib::numpy_sin", backward, setup_context=setup_context
             )
         elif mode == "opoverload":
             torch.library.register_autograd(
-                torch.ops.mylib.numpy_sin.default, setup_context, backward
+                torch.ops.mylib.numpy_sin.default,
+                backward,
+                setup_context=setup_context,
             )
 
         x = torch.randn(3, requires_grad=True)
@@ -2531,13 +2535,16 @@ class TestCustomOpAPI(TestCase):
         if mode == "qualname":
             torch.library.register_autograd(
-                "_torch_testing::sin5", setup_context, backward, lib=lib
+                "_torch_testing::sin5",
+                backward,
+                setup_context=setup_context,
+                lib=lib,
             )
         elif mode == "opoverload":
             torch.library.register_autograd(
                 torch.ops._torch_testing.sin5.default,
-                setup_context,
                 backward,
+                setup_context=setup_context,
                 lib=lib,
             )
 
         x = torch.randn(3, requires_grad=True)

View File

@@ -8,14 +8,14 @@ from . import utils
 
 
 class InfoProtocol(Protocol):
-    _setup_context_fn: Optional[Callable]
     _backward_fn: Optional[Callable]
+    _setup_context_fn: Optional[Callable]
 
 
 @dataclasses.dataclass
 class Info:
-    _setup_context_fn: Optional[Callable]
     _backward_fn: Optional[Callable]
+    _setup_context_fn: Optional[Callable]
 
 
 def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:

View File

@@ -341,7 +341,11 @@ class CustomOpDef:
         return fn
 
     def register_autograd(
-        self, setup_context_fn: Optional[Callable], backward_fn: Callable, /
+        self,
+        backward: Callable,
+        /,
+        *,
+        setup_context: Optional[Callable] = None,
     ) -> None:
         r"""Register a backward formula for this custom op.
 
@@ -388,7 +392,7 @@
             >>> x, = ctx.saved_tensors
             >>> return grad * x.cos()
             >>>
-            >>> numpy_sin.register_autograd(setup_context, backward)
+            >>> numpy_sin.register_autograd(backward, setup_context=setup_context)
             >>>
             >>> x = torch.randn(3, requires_grad=True)
             >>> y = numpy_sin(x)
@@ -404,8 +408,8 @@
                 f"a functional operator and register an autograd formula for that."
             )
 
-        self._backward_fn = backward_fn
-        self._setup_context_fn = setup_context_fn
+        self._backward_fn = backward
+        self._setup_context_fn = setup_context
 
     def _register_to_dispatcher(self) -> None:
         lib = self._lib

View File

@@ -602,7 +602,7 @@ def register_fake(
         return register(func)
 
 
-def register_autograd(op: _op_identifier, setup_context_fn: Optional[Callable], backward_fn: Callable, /, *, lib=None) -> None:
+def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_context: Optional[Callable] = None, lib=None) -> None:
     r"""Register a backward formula for this custom op.
 
     In order for an operator to work with autograd, you need to register
@@ -648,7 +648,7 @@ def register_autograd(op: _op_identifier, setup_context_fn: Optional[Callable],
         >>> x, = ctx.saved_tensors
         >>> return grad * x.cos()
        >>>
-        >>> torch.library.register_autograd("mylib::numpy_sin", setup_context, backward)
+        >>> torch.library.register_autograd("mylib::numpy_sin", backward, setup_context=setup_context)
         >>>
         >>> x = torch.randn(3, requires_grad=True)
         >>> y = numpy_sin(x)
@@ -662,7 +662,7 @@ def register_autograd(op: _op_identifier, setup_context_fn: Optional[Callable],
         op = op._name
     opdef = _maybe_get_opdef(op)
     if opdef is not None:
-        opdef.register_autograd(setup_context_fn, backward_fn)
+        opdef.register_autograd(backward, setup_context=setup_context)
         return
 
     assert isinstance(op, str)
@@ -676,7 +676,7 @@ def register_autograd(op: _op_identifier, setup_context_fn: Optional[Callable],
             f"a functional operator and register an autograd formula for that."
         )
 
-    info = _library.autograd.Info(setup_context_fn, backward_fn)
+    info = _library.autograd.Info(backward, setup_context)
     autograd_kernel = _library.autograd.make_autograd_impl(op, info)
     namespace, opname = torch._library.utils.parse_namespace(qualname)
     if lib is None:

View File

@@ -46,7 +46,7 @@ def numpy_cube_backward(ctx, grad_out, grad_dx):
     grad_x = numpy_mul(grad_out, dx) + 6 * numpy_mul(grad_dx, x)
     return grad_x
 
-numpy_cube.register_autograd(numpy_cube_setup_context, numpy_cube_backward)
+numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context)
 
 @torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=())
 def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
@@ -66,7 +66,7 @@ def numpy_mul_backward(ctx, grad_out):
     grad_y = grad_out * x if ctx.needs_input_grad[1] else None
     return grad_x, grad_y
 
-numpy_mul.register_autograd(numpy_mul_setup_context, numpy_mul_backward)
+numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context)
 
 @torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=())
 def numpy_sort(x: Tensor, dim: int) -> Tuple[Tensor, Tensor, Tensor]:
@@ -95,7 +95,7 @@ def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv):
     ind, ind_inv = ctx.saved_tensors
     return numpy_take(grad_out, ind_inv, ind, ctx.dim), None
 
-numpy_sort.register_autograd(numpy_sort_setup_context, numpy_sort_backward)
+numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context)
 
 @torch.library.custom_op("_torch_testing::numpy_take", mutates_args=())
@@ -123,7 +123,7 @@ def numpy_take_backward(ctx, grad_out):
     grad_x = numpy_take(grad_out, ind_inv, ind, ctx.dim)
     return grad_x, None, None, None
 
-numpy_take.register_autograd(numpy_take_setup_context, numpy_take_backward)
+numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context)
 
 @torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=())
 def numpy_nonzero(x: Tensor) -> Tensor:
@@ -165,7 +165,7 @@ def numpy_view_copy_setup_context(ctx, inputs, output) -> None:
 def numpy_view_copy_backward(ctx, grad_out):
     return torch.ops._torch_testing.numpy_view_copy(grad_out, ctx.x_shape), None
 
-numpy_view_copy.register_autograd(numpy_view_copy_setup_context, numpy_view_copy_backward)
+numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context)
 
 def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -188,7 +188,7 @@ def _(xs, dim):
     assert all(x.dtype == xs[0].dtype for x in xs)
     return torch.cat(xs, dim=dim)
 
-def numpy_cat_setup_backward(ctx, inputs, output):
+def numpy_cat_setup_context(ctx, inputs, output):
     xs, dim = inputs
     ctx.dim_sizes = [x.shape[dim] for x in xs]
     ctx.dim = dim
@@ -201,7 +201,7 @@ def numpy_cat_backward(ctx, grad_out):
     grad_xs = torch.ops._torch_testing.numpy_split_copy(grad_out, splits, dim)
     return grad_xs, None
 
-numpy_cat.register_autograd(numpy_cat_setup_backward, numpy_cat_backward)
+numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context)
 
 def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -228,7 +228,7 @@ def numpy_split_copy_backward(ctx, grad_out):
     result = torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim)
     return result, None, None
 
-numpy_split_copy.register_autograd(numpy_split_copy_setup_context, numpy_split_copy_backward)
+numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context)
 
 def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -252,7 +252,9 @@ def numpy_split_copy_with_int_setup_context(ctx, inputs, output):
 def numpy_split_copy_with_int_backward(ctx, grad_out, _):
     return torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim), None, None
 
-numpy_split_copy_with_int.register_autograd(numpy_split_copy_with_int_setup_context, numpy_split_copy_with_int_backward)
+numpy_split_copy_with_int.register_autograd(
+    numpy_split_copy_with_int_backward,
+    setup_context=numpy_split_copy_with_int_setup_context)
 
 @torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=())
 def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor: