Change register_autograd to reflect ordering of setup_context and backward (#124403)
old: `register_autograd(setup_context, backward, /)`
new: `register_autograd(backward, /, *, setup_context=None)`

Motivations:
- We introduce these APIs as "give us a backward and use setup_context to save things for backward".
- setup_context isn't always necessary.

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124403
Approved by: https://github.com/albanD
ghstack dependencies: #124180, #124200, #124299, #124134, #124199
commit 25c65d6642 (parent a8e17b2d4d)
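A minimal sketch of the new calling convention, expanded from the `register_autograd` docstring example touched below (the `mylib::numpy_sin` op is the docstring's illustrative example, not a built-in operator):

    import numpy as np
    import torch
    from torch import Tensor

    @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
    def numpy_sin(x: Tensor) -> Tensor:
        # Compute in NumPy; detach because the raw input may require grad.
        y_np = np.sin(x.detach().cpu().numpy())
        return torch.from_numpy(y_np).to(device=x.device)

    def setup_context(ctx, inputs, output):
        # Runs during the forward pass; saves whatever backward will need.
        x, = inputs
        ctx.save_for_backward(x)

    def backward(ctx, grad):
        x, = ctx.saved_tensors
        return grad * x.cos()

    # New ordering: backward is positional, setup_context is keyword-only.
    torch.library.register_autograd(
        "mylib::numpy_sin", backward, setup_context=setup_context
    )

    x = torch.randn(3, requires_grad=True)
    numpy_sin(x).sum().backward()  # x.grad == x.cos()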
@@ -2486,14 +2486,18 @@ class TestCustomOpAPI(TestCase):
             return grad * x.cos()

         if mode == "function":
-            torch.library.register_autograd(numpy_sin, setup_context, backward)
+            torch.library.register_autograd(
+                numpy_sin, backward, setup_context=setup_context
+            )
         elif mode == "qualname":
             torch.library.register_autograd(
-                "mylib::numpy_sin", setup_context, backward
+                "mylib::numpy_sin", backward, setup_context=setup_context
             )
         elif mode == "opoverload":
             torch.library.register_autograd(
-                torch.ops.mylib.numpy_sin.default, setup_context, backward
+                torch.ops.mylib.numpy_sin.default,
+                backward,
+                setup_context=setup_context,
             )

         x = torch.randn(3, requires_grad=True)
@@ -2531,13 +2535,16 @@ class TestCustomOpAPI(TestCase):

         if mode == "qualname":
             torch.library.register_autograd(
-                "_torch_testing::sin5", setup_context, backward, lib=lib
+                "_torch_testing::sin5",
+                backward,
+                setup_context=setup_context,
+                lib=lib,
             )
         elif mode == "opoverload":
             torch.library.register_autograd(
                 torch.ops._torch_testing.sin5.default,
-                setup_context,
                 backward,
+                setup_context=setup_context,
                 lib=lib,
             )
         x = torch.randn(3, requires_grad=True)
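As the updated tests exercise, the operator argument to torch.library.register_autograd can be given in any of three equivalent forms. Continuing the first sketch above (same numpy_sin, backward, and setup_context):

    # 1. the object returned by the @torch.library.custom_op decorator
    torch.library.register_autograd(numpy_sin, backward, setup_context=setup_context)
    # 2. the qualified name string
    torch.library.register_autograd("mylib::numpy_sin", backward, setup_context=setup_context)
    # 3. the OpOverload
    torch.library.register_autograd(torch.ops.mylib.numpy_sin.default, backward, setup_context=setup_context)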
@@ -8,14 +8,14 @@ from . import utils


 class InfoProtocol(Protocol):
-    _setup_context_fn: Optional[Callable]
     _backward_fn: Optional[Callable]
+    _setup_context_fn: Optional[Callable]


 @dataclasses.dataclass
 class Info:
-    _setup_context_fn: Optional[Callable]
     _backward_fn: Optional[Callable]
+    _setup_context_fn: Optional[Callable]


 def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
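Because Info is a dataclass, swapping the field declarations also swaps the positional constructor order, which is what lets the call site further down build it as Info(backward, setup_context). A tiny sketch of the effect (my_backward and my_setup_context are placeholder callables):

    import dataclasses
    from typing import Callable, Optional

    @dataclasses.dataclass
    class Info:
        _backward_fn: Optional[Callable]
        _setup_context_fn: Optional[Callable]

    def my_backward(ctx, grad): ...
    def my_setup_context(ctx, inputs, output): ...

    # Positional order now matches the public API: backward first.
    info = Info(my_backward, my_setup_context)
    assert info._backward_fn is my_backward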
@@ -341,7 +341,11 @@ class CustomOpDef:
         return fn

     def register_autograd(
-        self, setup_context_fn: Optional[Callable], backward_fn: Callable, /
+        self,
+        backward: Callable,
+        /,
+        *,
+        setup_context: Optional[Callable] = None,
     ) -> None:
         r"""Register a backward formula for this custom op.

@@ -388,7 +392,7 @@ class CustomOpDef:
             >>> x, = ctx.saved_tensors
             >>> return grad * x.cos()
             >>>
-            >>> numpy_sin.register_autograd(setup_context, backward)
+            >>> numpy_sin.register_autograd(backward, setup_context=setup_context)
             >>>
             >>> x = torch.randn(3, requires_grad=True)
             >>> y = numpy_sin(x)
@@ -404,8 +408,8 @@ class CustomOpDef:
                 f"a functional operator and register an autograd formula for that."
             )

-        self._backward_fn = backward_fn
-        self._setup_context_fn = setup_context_fn
+        self._backward_fn = backward
+        self._setup_context_fn = setup_context

     def _register_to_dispatcher(self) -> None:
         lib = self._lib
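The setup_context argument now defaults to None, reflecting the second motivation above: a backward formula that needs nothing saved from the forward pass can be registered on its own. A hypothetical sketch (mylib::my_double is an illustrative op, not part of this diff):

    import torch
    from torch import Tensor

    @torch.library.custom_op("mylib::my_double", mutates_args=())
    def my_double(x: Tensor) -> Tensor:
        return x * 2

    # The backward formula only needs the incoming gradient,
    # so no setup_context is registered.
    def backward(ctx, grad):
        return grad * 2

    my_double.register_autograd(backward)

    x = torch.randn(3, requires_grad=True)
    my_double(x).sum().backward()  # x.grad is all 2s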
@@ -602,7 +602,7 @@ def register_fake(
     return register(func)


-def register_autograd(op: _op_identifier, setup_context_fn: Optional[Callable], backward_fn: Callable, /, *, lib=None) -> None:
+def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_context: Optional[Callable] = None, lib=None) -> None:
     r"""Register a backward formula for this custom op.

     In order for an operator to work with autograd, you need to register
@@ -648,7 +648,7 @@ def register_autograd(op: _op_identifier, setup_context_fn: Optional[Callable],
         >>> x, = ctx.saved_tensors
         >>> return grad * x.cos()
         >>>
-        >>> torch.library.register_autograd("mylib::numpy_sin", setup_context, backward)
+        >>> torch.library.register_autograd("mylib::numpy_sin", backward, setup_context=setup_context)
         >>>
         >>> x = torch.randn(3, requires_grad=True)
         >>> y = numpy_sin(x)
@@ -662,7 +662,7 @@ def register_autograd(op: _op_identifier, setup_context_fn: Optional[Callable],
         op = op._name
     opdef = _maybe_get_opdef(op)
     if opdef is not None:
-        opdef.register_autograd(setup_context_fn, backward_fn)
+        opdef.register_autograd(backward, setup_context=setup_context)
         return

     assert isinstance(op, str)
@@ -676,7 +676,7 @@ def register_autograd(op: _op_identifier, setup_context_fn: Optional[Callable],
             f"a functional operator and register an autograd formula for that."
         )

-    info = _library.autograd.Info(setup_context_fn, backward_fn)
+    info = _library.autograd.Info(backward, setup_context)
     autograd_kernel = _library.autograd.make_autograd_impl(op, info)
     namespace, opname = torch._library.utils.parse_namespace(qualname)
     if lib is None:
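For operators not created with @torch.library.custom_op, this path builds an Info and installs the generated autograd kernel; it also accepts a lib= keyword, as the "_torch_testing::sin5" test above shows. A hypothetical sketch of that flow (the mylib2 namespace, the scale2 op, and its implementation are assumptions, not part of this diff):

    import torch

    lib = torch.library.Library("mylib2", "FRAGMENT")  # assumed fresh namespace
    lib.define("scale2(Tensor x) -> Tensor")

    @torch.library.impl(lib, "scale2", "CompositeExplicitAutograd")
    def scale2_impl(x):
        return x * 2

    def backward(ctx, grad):
        return grad * 2

    # Same backward-first form; lib= ties the registration to this Library.
    torch.library.register_autograd("mylib2::scale2", backward, lib=lib)

    x = torch.randn(3, requires_grad=True)
    torch.ops.mylib2.scale2(x).sum().backward()  # x.grad is all 2s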
@@ -46,7 +46,7 @@ def numpy_cube_backward(ctx, grad_out, grad_dx):
     grad_x = numpy_mul(grad_out, dx) + 6 * numpy_mul(grad_dx, x)
     return grad_x

-numpy_cube.register_autograd(numpy_cube_setup_context, numpy_cube_backward)
+numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context)

 @torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=())
 def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
@@ -66,7 +66,7 @@ def numpy_mul_backward(ctx, grad_out):
     grad_y = grad_out * x if ctx.needs_input_grad[1] else None
     return grad_x, grad_y

-numpy_mul.register_autograd(numpy_mul_setup_context, numpy_mul_backward)
+numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context)

 @torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=())
 def numpy_sort(x: Tensor, dim: int) -> Tuple[Tensor, Tensor, Tensor]:
@@ -95,7 +95,7 @@ def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv):
     ind, ind_inv = ctx.saved_tensors
     return numpy_take(grad_out, ind_inv, ind, ctx.dim), None

-numpy_sort.register_autograd(numpy_sort_setup_context, numpy_sort_backward)
+numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context)


 @torch.library.custom_op("_torch_testing::numpy_take", mutates_args=())
@@ -123,7 +123,7 @@ def numpy_take_backward(ctx, grad_out):
     grad_x = numpy_take(grad_out, ind_inv, ind, ctx.dim)
     return grad_x, None, None, None

-numpy_take.register_autograd(numpy_take_setup_context, numpy_take_backward)
+numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context)

 @torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=())
 def numpy_nonzero(x: Tensor) -> Tensor:
@@ -165,7 +165,7 @@ def numpy_view_copy_setup_context(ctx, inputs, output) -> None:
 def numpy_view_copy_backward(ctx, grad_out):
     return torch.ops._torch_testing.numpy_view_copy(grad_out, ctx.x_shape), None

-numpy_view_copy.register_autograd(numpy_view_copy_setup_context, numpy_view_copy_backward)
+numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context)

 def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -188,7 +188,7 @@ def _(xs, dim):
     assert all(x.dtype == xs[0].dtype for x in xs)
     return torch.cat(xs, dim=dim)

-def numpy_cat_setup_backward(ctx, inputs, output):
+def numpy_cat_setup_context(ctx, inputs, output):
     xs, dim = inputs
     ctx.dim_sizes = [x.shape[dim] for x in xs]
     ctx.dim = dim
@@ -201,7 +201,7 @@ def numpy_cat_backward(ctx, grad_out):
     grad_xs = torch.ops._torch_testing.numpy_split_copy(grad_out, splits, dim)
     return grad_xs, None

-numpy_cat.register_autograd(numpy_cat_setup_backward, numpy_cat_backward)
+numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context)

 def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -228,7 +228,7 @@ def numpy_split_copy_backward(ctx, grad_out):
     result = torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim)
     return result, None, None

-numpy_split_copy.register_autograd(numpy_split_copy_setup_context, numpy_split_copy_backward)
+numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context)

 def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -252,7 +252,9 @@ def numpy_split_copy_with_int_setup_context(ctx, inputs, output):
 def numpy_split_copy_with_int_backward(ctx, grad_out, _):
     return torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim), None, None

-numpy_split_copy_with_int.register_autograd(numpy_split_copy_with_int_setup_context, numpy_split_copy_with_int_backward)
+numpy_split_copy_with_int.register_autograd(
+    numpy_split_copy_with_int_backward,
+    setup_context=numpy_split_copy_with_int_setup_context)

 @torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=())
 def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor: