Extend impl_backward to be usable with torch.library operators (#106817)

- impl_save_for_backward/impl_backward only work for functional,
  non-view schemas. We validate this.
- impl_save_for_backward/impl_backward raise an error if an autograd
  implementation already exists via torch.library / TORCH_LIBRARY.
- Operators constructed via custom_op receive an "autograd indirection
  kernel": it automatically pulls the constructed autograd kernel out of
  a dict. When impl_save_for_backward/impl_backward are used with
  torch.library operators, we also register the autograd indirection
  kernel so that the same logic is reused (a usage sketch follows below).
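
As a usage sketch (adapted from the new test_backward_impl_on_existing_op
test in this PR; the "mylib" namespace and the torch.ops access path are
illustrative placeholders, not part of the change itself):

    import torch
    import torch._custom_ops as custom_ops

    # Operator defined directly via torch.library, not via custom_ops.custom_op.
    lib = torch.library.Library("mylib", "FRAGMENT")
    lib.define("foo(Tensor x) -> Tensor")

    @custom_ops.impl("mylib::foo")
    def foo_impl(x):
        with torch.no_grad():
            return x.sin()

    # These registrations now also work on the torch.library-defined operator;
    # under the hood they install the autograd indirection kernel.
    @custom_ops.impl_save_for_backward("mylib::foo")
    def foo_save_for_backward(inputs, output):
        return inputs.x

    @custom_ops.impl_backward("mylib::foo")
    def foo_backward(ctx, saved, grad_out):
        return {"x": grad_out * saved.cos()}

    x = torch.randn([], requires_grad=True)
    y = torch.ops.mylib.foo(x)
    (gx,) = torch.autograd.grad(y, x)
    assert torch.allclose(gx, x.cos())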

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106817
Approved by: https://github.com/soulitzer
ghstack dependencies: #106799, #106800
Richard Zou 2023-08-11 11:41:54 -07:00 committed by PyTorch MergeBot
parent db9a0cf689
commit 2932b0bf37
4 changed files with 161 additions and 12 deletions


@@ -1615,6 +1615,76 @@ def forward(self, x_1):
         result = op(x)
         self.assertEqual(result.shape, ())
 
+    def _test_backward_impl_raises(self, qualname, err_regex):
+        with self.assertRaisesRegex(RuntimeError, err_regex):
+            @custom_ops.impl_save_for_backward(qualname)
+            def foo2(x):
+                return
+
+        with self.assertRaisesRegex(RuntimeError, err_regex):
+            @custom_ops.impl_backward(qualname)
+            def foo3(x):
+                return
+
+    def test_backward_impl_on_existing_op_incorrect_schema_views(self):
+        lib = self.lib()
+        lib.define("foo(Tensor(a) x) -> Tensor(a)")
+        qualname = f"{self.test_ns}::foo"
+        self._test_backward_impl_raises(qualname, "operator that returns views")
+
+    def test_backward_impl_on_existing_op_incorrect_schema_mutable(self):
+        lib = self.lib()
+        lib.define("foo(Tensor(a!) x) -> Tensor")
+        qualname = f"{self.test_ns}::foo"
+        self._test_backward_impl_raises(qualname, "non-functional")
+
+    def test_backward_impl_on_existing_op_incorrect_schema_no_output(self):
+        lib = self.lib()
+        lib.define("foo(Tensor x) -> ()")
+        qualname = f"{self.test_ns}::foo"
+        self._test_backward_impl_raises(qualname, "no returns")
+
+    def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self):
+        lib = self.lib()
+        lib.define("foo(Tensor x) -> Tensor")
+        qualname = f"{self.test_ns}::foo"
+        lib.impl("foo", lambda x: x.sin().cos(), "CompositeImplicitAutograd")
+        self._test_backward_impl_raises(qualname, "CompositeImplicitAutograd")
+
+    @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"])
+    def test_backward_impl_on_existing_op_with_key(self, key):
+        lib = self.lib()
+        lib.define("foo(Tensor x) -> Tensor")
+        qualname = f"{self.test_ns}::foo"
+        lib.impl("foo", lambda x: x.sin().cos(), key)
+        self._test_backward_impl_raises(qualname, key)
+
+    def test_backward_impl_on_existing_op(self):
+        lib = self.lib()
+        lib.define("foo(Tensor x) -> Tensor")
+        qualname = f"{self.test_ns}::foo"
+
+        @custom_ops.impl(qualname)
+        def foo_impl(x):
+            with torch.no_grad():
+                return x.sin()
+
+        @custom_ops.impl_save_for_backward(qualname)
+        def foo_save_for_backward(inputs, output):
+            return inputs.x
+
+        @custom_ops.impl_backward(qualname)
+        def foo_backward(ctx, saved, grad_out):
+            return {"x": grad_out * saved.cos()}
+
+        op = self.get_op(qualname)
+        x = torch.randn([], requires_grad=True)
+        y = op(x)
+        (gx,) = torch.autograd.grad(y, x)
+        self.assertEqual(gx, x.cos())
+
 only_for = ("cpu", "cuda")
 instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)


@@ -61,6 +61,7 @@ def mark_non_differentiable(ctx, output, output_differentiability):
     # - Tensor
     # - Tensor[]
     # - int, bool, Scalar, float
+    # See _check_can_register_backward
     if output_differentiability is not None:
         if not isinstance(output, tuple):
             tuple_output = (output,)
@@ -90,7 +91,8 @@ def mark_non_differentiable(ctx, output, output_differentiability):
 def construct_autograd_kernel(
         schema,
         output_differentiability,
-        forward_op,
+        custom_op,
+        op_overload,
         save_for_backward_fn,
         backward_fn):
@@ -102,7 +104,7 @@ def construct_autograd_kernel(
             ctx.set_materialize_grads(True)
             args = pytree.tree_unflatten(list(flat_args), spec)
             with torch._C._AutoDispatchBelowAutograd():
-                output = forward_op(*args)
+                output = op_overload(*args)
 
             # We use the info about args to give better error messages in backward
             args_info = namedtuple_args(
@@ -131,11 +133,11 @@
             # Massage the grad_inputs_dict to a form acceptable by
             # autograd.Function.
-            validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info)
+            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
             return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)
 
         generated_cls = gen_autograd_function(
-            forward_op._opname + '_customop', forward, backward)
+            custom_op._opname + '_customop', forward, backward)
 
         flat_output = generated_cls.apply(*flat_args)
         assert out_spec is not None


@@ -6,7 +6,7 @@ import sys
 import typing
 import weakref
 
-from torchgen.model import FunctionSchema, OperatorName, SchemaKind
+from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy
 
 import torch
 import torch._C as _C
@@ -195,9 +195,16 @@ class CustomOp:
         # NB: Some of these impls are registered as kernels to DispatchKeys.
         # Modifying the _impls dict directly won't do anything in that case.
         self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
+        # See NOTE [CustomOp autograd kernel indirection]
+        self._registered_autograd_kernel_indirection = False
 
         global_registry[self._qualname] = self
 
+    def _register_autograd_kernel_indirection(self):
+        assert not self._registered_autograd_kernel_indirection
+        self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
+        self._registered_autograd_kernel_indirection = True
+
     # Records the impl and the source location in self._impls
     # Note that this doesn't cause torch.library to use the impl, that
     # needs to be done in a separate self._lib.impl call.
@@ -429,6 +436,67 @@ class CustomOp:
         return inner
 
+    def _check_can_register_backward(self):
+        def error(detail):
+            raise RuntimeError(
+                f"Cannot use torch._custom_ops APIs to register backward "
+                f"formula for {detail}. Got operator "
+                f"{self._qualname} with schema: {schema}"
+            )
+
+        schema = self._schema
+        if schema.kind() != SchemaKind.functional:
+            error("non-functional operator")
+
+        rets = schema.returns
+        if not schema.returns:
+            error("operator with no returns")
+
+        assert len(rets) > 0
+        is_non_mutating_view = any(
+            r.annotation is not None and not r.annotation.is_write for r in rets
+        )
+        if is_non_mutating_view:
+            error("operator that returns views")
+
+        # We make assumptions about the schema's return types.
+        allowed_return_types = {
+            BaseType(BaseTy.int): "int",
+            BaseType(BaseTy.SymInt): "SymInt",
+            BaseType(BaseTy.bool): "bool",
+            BaseType(BaseTy.float): "float",
+            BaseType(BaseTy.Tensor): "Tensor",
+            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
+        }
+        for ret in schema.returns:
+            if ret.type in allowed_return_types:
+                continue
+            error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")
+
+    def _check_doesnt_have_library_autograd_impl(self):
+        if self._registered_autograd_kernel_indirection:
+            return
+
+        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
+            raise RuntimeError(
+                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
+                f"already has an implementation for this device type via a "
+                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
+                f"CompositeImplicitAutograd operators do not need an autograd formula; "
+                f"instead, the operator will decompose into its constituents and those "
+                f"can have autograd formulas defined on them.")
+
+        # We can improve this by adding "all Autograd<BACKEND> keys", but
+        # realistically people will just be using this API for CPU/CUDA for now.
+        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
+            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
+                raise RuntimeError(
+                    f"impl_backward/impl_save_for_backward: "
+                    f"the operator {self._qualname} already has an Autograd kernel "
+                    f"registered to DispatchKey::{key} via a pre-existing "
+                    f"torch.library or TORCH_LIBRARY registration. Please either "
+                    f"remove those registrations or don't use the torch._custom_ops APIs")
+
     def _check_doesnt_have_library_meta_impl(self):
         if self._has_impl("abstract"):
             return
@@ -477,6 +545,7 @@ class CustomOp:
             self._schema,
             self._output_differentiability,
             self,
+            get_op(self._qualname),
             self._get_impl("save_for_backward").func,
             self._get_impl("backward").func)
         self._register_impl("autograd", kernel)
@@ -487,6 +556,10 @@ class CustomOp:
         Please see impl_backward for more details.
         """
         def inner(f):
+            self._check_can_register_backward()
+            self._check_doesnt_have_library_autograd_impl()
+            if not self._registered_autograd_kernel_indirection:
+                self._register_autograd_kernel_indirection()
             self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
             if self._has_impl("backward"):
                 self._register_autograd_kernel()
@@ -546,6 +619,10 @@ class CustomOp:
                     yell()
 
         def inner(f):
+            self._check_can_register_backward()
+            self._check_doesnt_have_library_autograd_impl()
+            if not self._registered_autograd_kernel_indirection:
+                self._register_autograd_kernel_indirection()
             self._register_impl("backward", f, stacklevel=_stacklevel)
             self._output_differentiability = output_differentiability
             if self._has_impl("save_for_backward"):
@@ -963,7 +1040,10 @@ def custom_op_from_existing(op):
     ns = op.namespace
     lib = torch.library.Library(ns, "FRAGMENT")
     name = op.name().split("::")[-1]
-    schema = FunctionSchema.parse(str(op._schema))
+    schema_str = str(op._schema)
+    # CustomOp expects the schema string without the namespace
+    schema_str = schema_str.split("::")[-1]
+    schema = FunctionSchema.parse(schema_str)
     return CustomOp(lib, ns, schema, name, op, _private_access=True)
@@ -1008,10 +1088,7 @@ def _custom_op_with_schema(qualname, schema):
     lib.define(schema_str)
     ophandle = find_ophandle_or_throw(ns, function_schema.name)
     result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
+    result._register_autograd_kernel_indirection()
 
-    library.impl(lib, result._opname, "Autograd")(
-        autograd_kernel_indirection(weakref.proxy(result))
-    )
     torch._C._dispatch_set_report_error_callback(
         ophandle, functools.partial(report_error_callback, weakref.proxy(result))


@@ -266,7 +266,7 @@ def impl_save_for_backward(qualname, *, func=None):
     """
     def inner(func):
-        custom_op = _find_custom_op(qualname)
+        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
         custom_op.impl_save_for_backward(_stacklevel=3)(func)
         return func
@@ -313,7 +313,7 @@ def impl_backward(qualname, output_differentiability=None, *, func=None):
     """
    def inner(func):
-        custom_op = _find_custom_op(qualname)
+        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
         custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
         return func