Extend impl_backward to be usable with torch.library operators (#106817)
- impl_save_for_backward/impl_backward only work for functional, non-view schemas. We validate this.
- impl_save_for_backward/impl_backward raise if there already exists an autograd implementation from torch.library / TORCH_LIBRARY.
- Operators constructed via custom_op receive an "autograd indirection kernel". The "autograd indirection kernel" automatically pulls the constructed autograd kernel out of a dict. When impl_save_for_backward/impl_backward get used with torch.library operators, we also register the "autograd indirection kernel" so we can reuse the logic.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106817
Approved by: https://github.com/soulitzer
ghstack dependencies: #106799, #106800
This commit is contained in:
  parent  db9a0cf689
  commit  2932b0bf37
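For orientation before the diff: the new test_backward_impl_on_existing_op test below exercises the extended path end to end. The following is a minimal sketch of that usage as a standalone script; the "mylib" namespace is an illustrative assumption, not part of the change.

import torch
import torch._custom_ops as custom_ops

# "mylib" is a hypothetical namespace used only for this sketch; the tests use a
# per-test namespace instead.
lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("foo(Tensor x) -> Tensor")

@custom_ops.impl("mylib::foo")
def foo_impl(x):
    with torch.no_grad():
        return x.sin()

@custom_ops.impl_save_for_backward("mylib::foo")
def foo_save_for_backward(inputs, output):
    # Save only what the backward formula needs.
    return inputs.x

@custom_ops.impl_backward("mylib::foo")
def foo_backward(ctx, saved, grad_out):
    # d/dx sin(x) = cos(x)
    return {"x": grad_out * saved.cos()}

x = torch.randn([], requires_grad=True)
y = torch.ops.mylib.foo(x)
(gx,) = torch.autograd.grad(y, x)
assert torch.allclose(gx, x.cos())

If the schema returned views, were non-functional, had no returns, or already had an Autograd or CompositeImplicitAutograd kernel registered via torch.library / TORCH_LIBRARY, impl_save_for_backward/impl_backward would raise instead; the new tests below cover each of those cases.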
@@ -1615,6 +1615,76 @@ def forward(self, x_1):
         result = op(x)
         self.assertEqual(result.shape, ())
 
+    def _test_backward_impl_raises(self, qualname, err_regex):
+        with self.assertRaisesRegex(RuntimeError, err_regex):
+
+            @custom_ops.impl_save_for_backward(qualname)
+            def foo2(x):
+                return
+
+        with self.assertRaisesRegex(RuntimeError, err_regex):
+
+            @custom_ops.impl_backward(qualname)
+            def foo3(x):
+                return
+
+    def test_backward_impl_on_existing_op_incorrect_schema_views(self):
+        lib = self.lib()
+        lib.define("foo(Tensor(a) x) -> Tensor(a)")
+        qualname = f"{self.test_ns}::foo"
+        self._test_backward_impl_raises(qualname, "operator that returns views")
+
+    def test_backward_impl_on_existing_op_incorrect_schema_mutable(self):
+        lib = self.lib()
+        lib.define("foo(Tensor(a!) x) -> Tensor")
+        qualname = f"{self.test_ns}::foo"
+        self._test_backward_impl_raises(qualname, "non-functional")
+
+    def test_backward_impl_on_existing_op_incorrect_schema_no_output(self):
+        lib = self.lib()
+        lib.define("foo(Tensor x) -> ()")
+        qualname = f"{self.test_ns}::foo"
+        self._test_backward_impl_raises(qualname, "no returns")
+
+    def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self):
+        lib = self.lib()
+        lib.define("foo(Tensor x) -> Tensor")
+        qualname = f"{self.test_ns}::foo"
+        lib.impl("foo", lambda x: x.sin().cos(), "CompositeImplicitAutograd")
+        self._test_backward_impl_raises(qualname, "CompositeImplicitAutograd")
+
+    @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"])
+    def test_backward_impl_on_existing_op_with_key(self, key):
+        lib = self.lib()
+        lib.define("foo(Tensor x) -> Tensor")
+        qualname = f"{self.test_ns}::foo"
+        lib.impl("foo", lambda x: x.sin().cos(), key)
+        self._test_backward_impl_raises(qualname, key)
+
+    def test_backward_impl_on_existing_op(self):
+        lib = self.lib()
+        lib.define("foo(Tensor x) -> Tensor")
+        qualname = f"{self.test_ns}::foo"
+
+        @custom_ops.impl(qualname)
+        def foo_impl(x):
+            with torch.no_grad():
+                return x.sin()
+
+        @custom_ops.impl_save_for_backward(qualname)
+        def foo_save_for_backward(inputs, output):
+            return inputs.x
+
+        @custom_ops.impl_backward(qualname)
+        def foo_backward(ctx, saved, grad_out):
+            return {"x": grad_out * saved.cos()}
+
+        op = self.get_op(qualname)
+        x = torch.randn([], requires_grad=True)
+        y = op(x)
+        (gx,) = torch.autograd.grad(y, x)
+        self.assertEqual(gx, x.cos())
+
+
 only_for = ("cpu", "cuda")
 instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
@@ -61,6 +61,7 @@ def mark_non_differentiable(ctx, output, output_differentiability):
     # - Tensor
     # - Tensor[]
     # - int, bool, Scalar, float
+    # See _check_can_register_backward
     if output_differentiability is not None:
         if not isinstance(output, tuple):
             tuple_output = (output,)
@@ -90,7 +91,8 @@ def mark_non_differentiable(ctx, output, output_differentiability):
 def construct_autograd_kernel(
         schema,
         output_differentiability,
-        forward_op,
+        custom_op,
+        op_overload,
         save_for_backward_fn,
         backward_fn):
 
@@ -102,7 +104,7 @@ def construct_autograd_kernel(
             ctx.set_materialize_grads(True)
             args = pytree.tree_unflatten(list(flat_args), spec)
             with torch._C._AutoDispatchBelowAutograd():
-                output = forward_op(*args)
+                output = op_overload(*args)
 
             # We use the info about args to give better error messages in backward
             args_info = namedtuple_args(
@@ -131,11 +133,11 @@ def construct_autograd_kernel(
 
             # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
-            validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info)
+            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
             return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)
 
         generated_cls = gen_autograd_function(
-            forward_op._opname + '_customop', forward, backward)
+            custom_op._opname + '_customop', forward, backward)
 
         flat_output = generated_cls.apply(*flat_args)
         assert out_spec is not None
@@ -6,7 +6,7 @@ import sys
 import typing
 import weakref
 
-from torchgen.model import FunctionSchema, OperatorName, SchemaKind
+from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy
 
 import torch
 import torch._C as _C
@@ -195,9 +195,16 @@ class CustomOp:
         # NB: Some of these impls are registered as kernels to DispatchKeys.
         # Modifying the _impls dict directly won't do anything in that case.
         self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
+        # See NOTE [CustomOp autograd kernel indirection]
+        self._registered_autograd_kernel_indirection = False
 
         global_registry[self._qualname] = self
 
+    def _register_autograd_kernel_indirection(self):
+        assert not self._registered_autograd_kernel_indirection
+        self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
+        self._registered_autograd_kernel_indirection = True
+
     # Records the impl and the source location in self._impls
     # Note that this doesn't cause torch.library to use the impl, that
     # needs to be done in a separate self._lib.impl call.
@@ -429,6 +436,67 @@ class CustomOp:
 
         return inner
 
+    def _check_can_register_backward(self):
+        def error(detail):
+            raise RuntimeError(
+                f"Cannot use torch._custom_ops APIs to register backward "
+                f"formula for {detail}. Got operator "
+                f"{self._qualname} with schema: {schema}"
+            )
+
+        schema = self._schema
+        if schema.kind() != SchemaKind.functional:
+            error("non-functional operator")
+
+        rets = schema.returns
+        if not schema.returns:
+            error("operator with no returns")
+
+        assert len(rets) > 0
+        is_non_mutating_view = any(
+            r.annotation is not None and not r.annotation.is_write for r in rets
+        )
+        if is_non_mutating_view:
+            error("operator that returns views")
+
+        # We make assumptions about the schema's return types.
+        allowed_return_types = {
+            BaseType(BaseTy.int): "int",
+            BaseType(BaseTy.SymInt): "SymInt",
+            BaseType(BaseTy.bool): "bool",
+            BaseType(BaseTy.float): "float",
+            BaseType(BaseTy.Tensor): "Tensor",
+            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
+        }
+        for ret in schema.returns:
+            if ret.type in allowed_return_types:
+                continue
+            error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")
+
+    def _check_doesnt_have_library_autograd_impl(self):
+        if self._registered_autograd_kernel_indirection:
+            return
+
+        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
+            raise RuntimeError(
+                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
+                f"already has an implementation for this device type via a "
+                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
+                f"CompositeImplicitAutograd operators do not need an autograd formula; "
+                f"instead, the operator will decompose into its constituents and those "
+                f"can have autograd formulas defined on them.")
+
+        # We can improve this by adding "all Autograd<BACKEND> keys", but
+        # realistically people will just be using this API for CPU/CUDA for now.
+        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
+            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
+                raise RuntimeError(
+                    f"impl_backward/impl_save_for_backward: "
+                    f"the operator {self._qualname} already has an Autograd kernel "
+                    f"registered to DispatchKey::{key} via a pre-existing "
+                    f"torch.library or TORCH_LIBRARY registration. Please either "
+                    f"remove those registrations or don't use the torch._custom_ops APIs")
+
     def _check_doesnt_have_library_meta_impl(self):
         if self._has_impl("abstract"):
             return
@@ -477,6 +545,7 @@ class CustomOp:
             self._schema,
             self._output_differentiability,
             self,
+            get_op(self._qualname),
             self._get_impl("save_for_backward").func,
             self._get_impl("backward").func)
         self._register_impl("autograd", kernel)
@@ -487,6 +556,10 @@ class CustomOp:
         Please see impl_backward for more details.
         """
         def inner(f):
+            self._check_can_register_backward()
+            self._check_doesnt_have_library_autograd_impl()
+            if not self._registered_autograd_kernel_indirection:
+                self._register_autograd_kernel_indirection()
             self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
             if self._has_impl("backward"):
                 self._register_autograd_kernel()
@@ -546,6 +619,10 @@ class CustomOp:
                 yell()
 
         def inner(f):
+            self._check_can_register_backward()
+            self._check_doesnt_have_library_autograd_impl()
+            if not self._registered_autograd_kernel_indirection:
+                self._register_autograd_kernel_indirection()
             self._register_impl("backward", f, stacklevel=_stacklevel)
             self._output_differentiability = output_differentiability
             if self._has_impl("save_for_backward"):
@@ -963,7 +1040,10 @@ def custom_op_from_existing(op):
     ns = op.namespace
     lib = torch.library.Library(ns, "FRAGMENT")
     name = op.name().split("::")[-1]
-    schema = FunctionSchema.parse(str(op._schema))
+    schema_str = str(op._schema)
+    # CustomOp expects the schema string without the namespace
+    schema_str = schema_str.split("::")[-1]
+    schema = FunctionSchema.parse(schema_str)
     return CustomOp(lib, ns, schema, name, op, _private_access=True)
 
 
@@ -1008,10 +1088,7 @@ def _custom_op_with_schema(qualname, schema):
     lib.define(schema_str)
     ophandle = find_ophandle_or_throw(ns, function_schema.name)
     result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
-
-    library.impl(lib, result._opname, "Autograd")(
-        autograd_kernel_indirection(weakref.proxy(result))
-    )
+    result._register_autograd_kernel_indirection()
 
     torch._C._dispatch_set_report_error_callback(
         ophandle, functools.partial(report_error_callback, weakref.proxy(result))
@@ -266,7 +266,7 @@ def impl_save_for_backward(qualname, *, func=None):
     """
 
     def inner(func):
-        custom_op = _find_custom_op(qualname)
+        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
         custom_op.impl_save_for_backward(_stacklevel=3)(func)
         return func
 
@@ -313,7 +313,7 @@ def impl_backward(qualname, output_differentiability=None, *, func=None):
     """
 
     def inner(func):
-        custom_op = _find_custom_op(qualname)
+        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
         return func
 