Extend impl_backward to be usable with torch.library operators (#106817)

- impl_save_for_backward/impl_backward only work for functional,
non-view schemas. We validate this.
- impl_save_for_backward/impl_backward raise if an autograd
implementation has already been registered via torch.library / TORCH_LIBRARY.
- Operators constructed via custom_op receive an "autograd indirection
kernel", which pulls the constructed autograd kernel out of a dict at
dispatch time. When impl_save_for_backward/impl_backward are used with
torch.library operators, we register the same "autograd indirection
kernel" so that this logic can be reused (see the sketch below).
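
For illustration, a minimal sketch of the new flow, using a hypothetical
"mylib" namespace; it mirrors the new test_backward_impl_on_existing_op
test added below:

import torch
import torch._custom_ops as custom_ops

# Operator defined through torch.library rather than through custom_op.
lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("foo(Tensor x) -> Tensor")

@custom_ops.impl("mylib::foo")
def foo_impl(x):
    with torch.no_grad():
        return x.sin()

# These now also accept torch.library-defined operators; registering a
# backward formula installs the autograd indirection kernel on the op.
@custom_ops.impl_save_for_backward("mylib::foo")
def foo_save_for_backward(inputs, output):
    return inputs.x

@custom_ops.impl_backward("mylib::foo")
def foo_backward(ctx, saved, grad_out):
    return {"x": grad_out * saved.cos()}

x = torch.randn([], requires_grad=True)
y = torch.ops.mylib.foo(x)
(gx,) = torch.autograd.grad(y, x)
assert torch.allclose(gx, x.cos())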

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106817
Approved by: https://github.com/soulitzer
ghstack dependencies: #106799, #106800
Richard Zou 2023-08-11 11:41:54 -07:00 committed by PyTorch MergeBot
parent db9a0cf689
commit 2932b0bf37
4 changed files with 161 additions and 12 deletions

View File

@@ -1615,6 +1615,76 @@ def forward(self, x_1):
        result = op(x)
        self.assertEqual(result.shape, ())

    def _test_backward_impl_raises(self, qualname, err_regex):
        with self.assertRaisesRegex(RuntimeError, err_regex):
            @custom_ops.impl_save_for_backward(qualname)
            def foo2(x):
                return

        with self.assertRaisesRegex(RuntimeError, err_regex):
            @custom_ops.impl_backward(qualname)
            def foo3(x):
                return
    def test_backward_impl_on_existing_op_incorrect_schema_views(self):
        lib = self.lib()
        lib.define("foo(Tensor(a) x) -> Tensor(a)")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "operator that returns views")

    def test_backward_impl_on_existing_op_incorrect_schema_mutable(self):
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "non-functional")

    def test_backward_impl_on_existing_op_incorrect_schema_no_output(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> ()")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "no returns")

    def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", lambda x: x.sin().cos(), "CompositeImplicitAutograd")
        self._test_backward_impl_raises(qualname, "CompositeImplicitAutograd")

    @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"])
    def test_backward_impl_on_existing_op_with_key(self, key):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", lambda x: x.sin().cos(), key)
        self._test_backward_impl_raises(qualname, key)
    def test_backward_impl_on_existing_op(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @custom_ops.impl(qualname)
        def foo_impl(x):
            with torch.no_grad():
                return x.sin()

        @custom_ops.impl_save_for_backward(qualname)
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(qualname)
        def foo_backward(ctx, saved, grad_out):
            return {"x": grad_out * saved.cos()}

        op = self.get_op(qualname)
        x = torch.randn([], requires_grad=True)
        y = op(x)
        (gx,) = torch.autograd.grad(y, x)
        self.assertEqual(gx, x.cos())
only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)

View File

@@ -61,6 +61,7 @@ def mark_non_differentiable(ctx, output, output_differentiability):
    # - Tensor
    # - Tensor[]
    # - int, bool, Scalar, float
    # See _check_can_register_backward
    if output_differentiability is not None:
        if not isinstance(output, tuple):
            tuple_output = (output,)
@@ -90,7 +91,8 @@ def mark_non_differentiable(ctx, output, output_differentiability):

def construct_autograd_kernel(
        schema,
        output_differentiability,
        forward_op,
        custom_op,
        op_overload,
        save_for_backward_fn,
        backward_fn):
@@ -102,7 +104,7 @@ def construct_autograd_kernel(
            ctx.set_materialize_grads(True)
            args = pytree.tree_unflatten(list(flat_args), spec)
            with torch._C._AutoDispatchBelowAutograd():
                output = forward_op(*args)
                output = op_overload(*args)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(
@@ -131,11 +133,11 @@ def construct_autograd_kernel(
            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info)
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        generated_cls = gen_autograd_function(
            forward_op._opname + '_customop', forward, backward)
            custom_op._opname + '_customop', forward, backward)

        flat_output = generated_cls.apply(*flat_args)
        assert out_spec is not None
View File

@@ -6,7 +6,7 @@ import sys
import typing
import weakref
from torchgen.model import FunctionSchema, OperatorName, SchemaKind
from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy
import torch
import torch._C as _C
@@ -195,9 +195,16 @@ class CustomOp:
        # NB: Some of these impls are registered as kernels to DispatchKeys.
        # Modifying the _impls dict directly won't do anything in that case.
        self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
        # See NOTE [CustomOp autograd kernel indirection]
        self._registered_autograd_kernel_indirection = False

        global_registry[self._qualname] = self

    def _register_autograd_kernel_indirection(self):
        assert not self._registered_autograd_kernel_indirection
        self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
        self._registered_autograd_kernel_indirection = True

    # Records the impl and the source location in self._impls
    # Note that this doesn't cause torch.library to use the impl, that
    # needs to be done in a separate self._lib.impl call.
@@ -429,6 +436,67 @@ class CustomOp:
            return inner

    def _check_can_register_backward(self):
        def error(detail):
            raise RuntimeError(
                f"Cannot use torch._custom_ops APIs to register backward "
                f"formula for {detail}. Got operator "
                f"{self._qualname} with schema: {schema}"
            )

        schema = self._schema
        if schema.kind() != SchemaKind.functional:
            error("non-functional operator")

        rets = schema.returns
        if not schema.returns:
            error("operator with no returns")

        assert len(rets) > 0
        is_non_mutating_view = any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        if is_non_mutating_view:
            error("operator that returns views")

        # We make assumptions about the schema's return types.
        allowed_return_types = {
            BaseType(BaseTy.int): "int",
            BaseType(BaseTy.SymInt): "SymInt",
            BaseType(BaseTy.bool): "bool",
            BaseType(BaseTy.float): "float",
            BaseType(BaseTy.Tensor): "Tensor",
            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
        }
        for ret in schema.returns:
            if ret.type in allowed_return_types:
                continue
            error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")
    def _check_doesnt_have_library_autograd_impl(self):
        if self._registered_autograd_kernel_indirection:
            return

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an autograd formula; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have autograd formulas defined on them.")

        # We can improve this by adding "all Autograd<BACKEND> keys", but
        # realistically people will just be using this API for CPU/CUDA for now.
        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
                raise RuntimeError(
                    f"impl_backward/impl_save_for_backward: "
                    f"the operator {self._qualname} already has an Autograd kernel "
                    f"registered to DispatchKey::{key} via a pre-existing "
                    f"torch.library or TORCH_LIBRARY registration. Please either "
                    f"remove those registrations or don't use the torch._custom_ops APIs")
    def _check_doesnt_have_library_meta_impl(self):
        if self._has_impl("abstract"):
            return
@@ -477,6 +545,7 @@ class CustomOp:
            self._schema,
            self._output_differentiability,
            self,
            get_op(self._qualname),
            self._get_impl("save_for_backward").func,
            self._get_impl("backward").func)
        self._register_impl("autograd", kernel)
@@ -487,6 +556,10 @@ class CustomOp:
        Please see impl_backward for more details.
        """
        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()

            self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
            if self._has_impl("backward"):
                self._register_autograd_kernel()
@@ -546,6 +619,10 @@ class CustomOp:
                    yell()

        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()

            self._register_impl("backward", f, stacklevel=_stacklevel)
            self._output_differentiability = output_differentiability
            if self._has_impl("save_for_backward"):
@@ -963,7 +1040,10 @@ def custom_op_from_existing(op):
    ns = op.namespace
    lib = torch.library.Library(ns, "FRAGMENT")
    name = op.name().split("::")[-1]
    schema = FunctionSchema.parse(str(op._schema))
    schema_str = str(op._schema)
    # CustomOp expects the schema string without the namespace
    schema_str = schema_str.split("::")[-1]
    schema = FunctionSchema.parse(schema_str)
    return CustomOp(lib, ns, schema, name, op, _private_access=True)
@@ -1008,10 +1088,7 @@ def _custom_op_with_schema(qualname, schema):
    lib.define(schema_str)
    ophandle = find_ophandle_or_throw(ns, function_schema.name)
    result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)

    library.impl(lib, result._opname, "Autograd")(
        autograd_kernel_indirection(weakref.proxy(result))
    )
    result._register_autograd_kernel_indirection()

    torch._C._dispatch_set_report_error_callback(
        ophandle, functools.partial(report_error_callback, weakref.proxy(result))
View File

@@ -266,7 +266,7 @@ def impl_save_for_backward(qualname, *, func=None):
    """
    def inner(func):
        custom_op = _find_custom_op(qualname)
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        custom_op.impl_save_for_backward(_stacklevel=3)(func)
        return func
@@ -313,7 +313,7 @@ def impl_backward(qualname, output_differentiability=None, *, func=None):
    """
    def inner(func):
        custom_op = _find_custom_op(qualname)
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
        return func