mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[reland] Deprecate registering autograd kernels at not an autograd key (#105078)
Summary:
Context
-------
This PR adds a new fallback to the Autograd dispatch keys.
If you would prefer the old behavior:
- A quick (unsupported) way to get the previous behavior is to call
`torch._C._set_autograd_fallback_mode("nothing")`
- Register `torch::CppFunction::makeFallthrough()` to your operator's Autograd
key, like in https://gist.github.com/zou3519/d09a5f4b1afe2430af09fea67c6ff2c8
It is possible that this PR regresses performance of overhead-bound
models. If this is the case, please reach out (and apply one of the
temporary fixes in the previous section).
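For example, the quick opt-out is a single call through the private bindings
this PR adds (unsupported and subject to change):

```python
import torch

# Restore the pre-PR behavior process-wide: the Autograd fallback becomes a
# plain redispatch and never warns.
torch._C._set_autograd_fallback_mode("nothing")
assert torch._C._get_autograd_fallback_mode() == "nothing"
```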
Description for reviewers
-------------------------
In order to deprecate registering autograd kernels at a key that is not an
Autograd key, we add a fallback to the Autograd dispatch keys. This fallback
raises a warning if the user attempts to backprop through the operator; it is
also configurable to either warn or not warn.
The goals of this PR are to:
- preserve as much backwards compatibility (BC) as possible
- raise a warning that whatever the user is doing is potentially wrong
- be as performant as possible
There are roughly two cases:
- if the post-autograd kernels return a Tensor that requires grad, then
we install an autograd hook that raises a warning. We are preserving BC
in that it is possible that the user has a torch::autograd::Function
registered to their CPU key.
- if the post-autograd kernels return Tensors that do not require grad, then
we make them require grad and install a WarnNotImplemented grad fn that warns
in the backward pass. This is mildly BC-breaking (see the BC-Breaking Note
below); a sketch of this second case follows the list.
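A minimal sketch of the second case (the `_example` namespace and `foo` op are
made up for illustration; the kernel detaches its input to mimic a backend
kernel that does not track gradients):

```python
import torch
from torch.library import Library

lib = Library("_example", "DEF")  # hypothetical namespace
lib.define("foo(Tensor x) -> Tensor")
# CPU-only kernel; no Autograd kernel is registered for this op.
lib.impl("foo", lambda x: x.detach().sin(), "CPU")

x = torch.randn(3, requires_grad=True)
out = torch.ops._example.foo(x)
# Under the default "warn" mode, the fallback marks the floating-point
# output as requiring grad and attaches a WarnNotImplemented grad_fn.
print(out.requires_grad)  # True
out.sum().backward()      # warns: "an autograd kernel was not registered ..."
print(x.grad)             # None: the fallback produces undefined gradients
```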
Test Plan:
- a bunch of new tests (test/autograd/test_fallback.py, added in this diff)
BC-Breaking Note
----------------
This PR adds a new fallback to the Autograd dispatch keys. It affects
custom operators that do not have a kernel registered to the Autograd
keys (e.g. AutogradCPU and AutogradCUDA).
Previously, such a custom operator could return Tensors that do not require
grad even when the inputs do require grad; with this PR, all floating-point
and complex returns now require grad in that case. See the "Context" section
above for how to get the old behavior; a short sketch follows.
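Concretely (same kind of made-up CPU-only op as above; the mode toggle is the
private binding added by this PR):

```python
import torch
from torch.library import Library

lib = Library("_example2", "DEF")  # hypothetical namespace
lib.define("bar(Tensor x) -> Tensor")
lib.impl("bar", lambda x: x.detach().clone(), "CPU")  # no Autograd kernel

x = torch.randn(3, requires_grad=True)

torch._C._set_autograd_fallback_mode("warn")     # new default behavior
print(torch.ops._example2.bar(x).requires_grad)  # True

torch._C._set_autograd_fallback_mode("nothing")  # previous behavior
print(torch.ops._example2.bar(x).requires_grad)  # False
```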
Differential Revision: D47408353
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105078
Approved by: https://github.com/soulitzer
This commit is contained in:
parent b4d91b1c5b
commit f03a8f0589
@@ -268,6 +268,7 @@ test_dynamo_shard() {
       test_package \
       test_legacy_vmap \
       test_custom_op_testing \
+      test_content_store \
       export/test_db \
       functorch/test_dims \
       functorch/test_aotdispatch \
@@ -1,5 +1,6 @@
-#include <ATen/core/dispatch/Dispatcher.h>
 #include <ATen/core/LegacyTypeDispatch.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/core/VariableHooksInterface.h>
 #include <torch/library.h>
 
 /*
@@ -27,36 +28,43 @@ namespace {
 // NB: But not the private use ones; maybe the extension wants
 // to override it themselves!
 
+void autograd_fallback(
+    const c10::OperatorHandle& op,
+    c10::DispatchKeySet dispatch_keys,
+    torch::jit::Stack* stack);
+
+#define AUTOGRAD_FALLBACK torch::CppFunction::makeFromBoxedFunction<&autograd_fallback>()
+
 TORCH_LIBRARY_IMPL(_, AutogradOther, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(AUTOGRAD_FALLBACK);
 }
 
 TORCH_LIBRARY_IMPL(_, AutogradCPU, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(AUTOGRAD_FALLBACK);
 }
 
 TORCH_LIBRARY_IMPL(_, AutogradXPU, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(AUTOGRAD_FALLBACK);
 }
 
 TORCH_LIBRARY_IMPL(_, AutogradCUDA, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(AUTOGRAD_FALLBACK);
 }
 
 TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(AUTOGRAD_FALLBACK);
 }
 
 TORCH_LIBRARY_IMPL(_, AutogradLazy, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(AUTOGRAD_FALLBACK);
 }
 
 TORCH_LIBRARY_IMPL(_, AutogradMPS, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(AUTOGRAD_FALLBACK);
 }
 
 TORCH_LIBRARY_IMPL(_, AutogradMeta, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(AUTOGRAD_FALLBACK);
 }
 
 // see Note [ADInplaceOrView key]
@@ -65,7 +73,24 @@ TORCH_LIBRARY_IMPL(_, ADInplaceOrView, m) {
 }
 
 TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(AUTOGRAD_FALLBACK);
 }
 
+#undef AUTOGRAD_FALLBACK
+
+void autograd_fallback(
+    const c10::OperatorHandle& op,
+    c10::DispatchKeySet dispatch_keys,
+    torch::jit::Stack* stack) {
+  // PyTorch has separate builds, some of which don't include autograd.
+  // So we define some behavior for when autograd isn't included and
+  // go through a layer of indirection (VariableHooksInterface) when it is.
+  // See aten/src/ATen/core/VariableHooksInterface.h for more details.
+  if (!at::impl::HasVariableHooks()) {
+    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
+    return;
+  }
+  at::impl::GetVariableHooks()->basic_autograd_not_implemented_fallback(op, dispatch_keys, stack);
+}
+
 } // namespace
@@ -13,5 +13,8 @@ VariableHooksInterface* GetVariableHooks() {
   TORCH_CHECK(hooks, "Support for autograd has not been loaded; have you linked against libtorch.so?")
   return hooks;
 }
+bool HasVariableHooks() {
+  return hooks != nullptr;
+}
 
 }} // namespace at::impl
@@ -59,10 +59,12 @@ struct TORCH_API VariableHooksInterface {
   virtual bool retains_grad(const TensorBase&) const = 0;
   virtual void _backward(const Tensor&, TensorList, const c10::optional<Tensor>&, c10::optional<bool>, bool) const = 0;
   virtual void requires_grad_(const TensorBase&, bool) const = 0;
+  virtual void basic_autograd_not_implemented_fallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) const = 0;
 };
 
 TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);
 TORCH_API VariableHooksInterface* GetVariableHooks();
+TORCH_API bool HasVariableHooks();
 
 struct TORCH_API VariableHooksRegisterer {
   explicit VariableHooksRegisterer(VariableHooksInterface* hooks) {
test/autograd/test_fallback.py (new file, 372 lines)
@@ -0,0 +1,372 @@
# Owner(s): ["module: autograd"]

import torch
from torch.library import Library
from torch.testing._internal.common_utils import (
    TestCase,
    parametrize,
    instantiate_parametrized_tests,
    run_tests,
)
import contextlib
import numpy as np
import warnings


@contextlib.contextmanager
def autograd_fallback_mode(mode):
    prev = torch._C._get_autograd_fallback_mode()
    try:
        torch._C._set_autograd_fallback_mode(mode)
        yield
    finally:
        torch._C._set_autograd_fallback_mode(prev)


class TestAutogradFallback(TestCase):
    test_ns = '_test_autograd_fallback'

    def tearDown(self):
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        if hasattr(self, 'lib'):
            del self.lib.m
            del self.lib

    def get_op(self, name):
        return getattr(getattr(torch.ops, self.test_ns), name).default

    def get_lib(self):
        lib = Library(self.test_ns, "FRAGMENT")
        self.lib = lib
        return lib

    @parametrize("mode", ("nothing", "warn"))
    def test_no_grad(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            lib.impl("foo", lambda a, b, c: a + b + c, "CPU")
            op = self.get_op("foo")

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                with torch.no_grad():
                    a = torch.randn([], requires_grad=True)
                    b = torch.randn([], requires_grad=True)
                    out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

            with warnings.catch_warnings():
                warnings.simplefilter("error")
                a = torch.randn([])
                b = torch.randn([])
                out = op(a, b, 1)
                self.assertFalse(out.requires_grad)

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
            op = self.get_op("foo")

            def foo_impl(a, b, c):
                result = a.detach().numpy() + b.detach().numpy() + c
                return torch.tensor(result)

            lib.impl("foo", foo_impl, "CPU")

            # Some inputs requiring grad
            a = torch.randn([], requires_grad=False)
            b = torch.randn([], requires_grad=True)
            out = op(a, b, 1).sum()
            with self._check_ctx(mode, mode_nothing_raises=True):
                out.backward()
            self.assertIsNone(b.grad)

    def _check_ctx(self, mode, *, mode_nothing_raises=False):
        if mode == "warn":
            return self.assertWarnsRegex(UserWarning, 'an autograd kernel was not registered')
        assert mode == "nothing"
        if mode_nothing_raises:
            return self.assertRaisesRegex(RuntimeError, "does not require grad")
        return contextlib.nullcontext()

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel_inplace(self, mode):
        with autograd_fallback_mode(mode):
            # input modified in-place gets returned as output
            lib = self.get_lib()
            lib.define("foo(Tensor(a!) self, Tensor(b!) y) -> (Tensor(a!), Tensor(b!))")
            op = self.get_op("foo")

            def foo_impl(x, y):
                with torch.no_grad():
                    x.sin_()
                    y.cos_()
                return x, y

            lib.impl("foo", foo_impl, "CPU")

            x = torch.randn(3, requires_grad=True)
            w = x.clone()
            v = x.clone()
            y0 = w[0]
            y1 = v[1]
            z0, z1 = op(y0, y1)
            for tensor in [w, v, z0, z1, y0, y1]:
                with self._check_ctx(mode):
                    tensor.sum().backward(retain_graph=True)

            # no outputs: we don't do anything. Maybe we should in the future.
            # This is not a common failure mode.
            lib.define("bar(Tensor(a!) self) -> ()")
            op = self.get_op("bar")

            def bar_impl(x):
                with torch.no_grad():
                    x.sin_()

            lib.impl("bar", bar_impl, "CPU")
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                x = torch.randn([], requires_grad=True)
                y = x.clone()
                z = op(y)
                y.backward()
                self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_cpu_return_self(self, mode):
        with autograd_fallback_mode(mode):
            # To be clear, none of these situations are OK and will lead
            # to other problems down the line. We're testing them because
            # it is fairly common to actually do these things.
            lib = Library(self.test_ns, "FRAGMENT")
            lib.define("foo(Tensor self) -> Tensor")
            lib.impl("foo", lambda x: x, "CPU")
            op = self.get_op("foo")

            x = torch.randn(3, requires_grad=True)
            y = op(x).sum()
            with self._check_ctx(mode):
                y.backward()
            self.assertEqual(x.grad, torch.ones_like(x))

            lib.define("bar(Tensor(a!) self) -> Tensor(a!)")
            lib.impl("bar", lambda x: x, "CPU")
            op = self.get_op("bar")

            x = torch.randn(3, requires_grad=True)
            y = op(x).sum()
            with self._check_ctx(mode):
                y.backward()
            self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_composite_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            lib = Library(self.test_ns, "FRAGMENT")
            lib.define("foo(Tensor self) -> Tensor")
            lib.impl("foo", lambda x: x.sin().sum(), "CPU")
            op = self.get_op("foo")

            x = torch.randn(3, requires_grad=True)
            y = op(x)
            with self._check_ctx(mode):
                y.backward()
            self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_autograd_function_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            lib = Library(self.test_ns, "FRAGMENT")
            lib.define("foo(Tensor self) -> Tensor")

            class NumpySin(torch.autograd.Function):
                @staticmethod
                def forward(ctx, x):
                    ctx.save_for_backward(x)
                    return torch.tensor(np.sin(x.cpu().numpy()))

                @staticmethod
                def backward(ctx, gx):
                    x, = ctx.saved_tensors
                    return gx * x.cos()

            lib.impl("foo", NumpySin.apply, "CPU")
            op = self.get_op("foo")

            x = torch.randn(3, requires_grad=True)
            y = op(x).sum()
            with self._check_ctx(mode):
                y.backward()
            self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_autograd_function_registered_to_cpu(self, mode):
        with autograd_fallback_mode(mode):
            lib = Library(self.test_ns, "FRAGMENT")
            lib.define("foo(Tensor(a!) self) -> Tensor(a!)")

            class NumpySin_(torch.autograd.Function):
                @staticmethod
                def forward(ctx, x):
                    ctx.save_for_backward(x.clone())
                    x_np = x.detach().numpy()
                    np.sin(x_np, out=x_np)
                    ctx.mark_dirty(x)
                    return x

                @staticmethod
                def backward(ctx, gx):
                    x, = ctx.saved_tensors
                    return gx * x.cos()

            lib.impl("foo", NumpySin_.apply, "CPU")
            op = self.get_op("foo")

            x = torch.randn(3, requires_grad=True)
            z = x.clone()
            w = z[0]
            y = op(w)

            expected = torch.zeros_like(x)
            expected[0] = x[0].cos()
            with self._check_ctx(mode):
                gx, = torch.autograd.grad(y, x, torch.ones_like(y), retain_graph=True)
            self.assertEqual(gx, expected)

            expected = torch.ones_like(x)
            expected[0] = x[0].cos()
            with self._check_ctx(mode):
                gx, = torch.autograd.grad(z, x, torch.ones_like(z))
            self.assertEqual(gx, expected)

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_on_tensor_that_does_not_require_grad(self, mode):
        # We don't do anything special (that is, we don't rebase history).
        # See NOTE [autograd fallback and in-place operations] for why
        with autograd_fallback_mode(mode):
            lib = Library(self.test_ns, "FRAGMENT")

            # Correct usage of (a!)
            lib.define("foo(Tensor(a!) self, Tensor other) -> Tensor(a!)")

            def foo_impl(x, y):
                x_d = x.detach()
                y = y.detach()
                x_d.add_(y)
                return x

            lib.impl("foo", foo_impl, "CPU")
            foo = self.get_op("foo")

            # Incorrect usage of (a!): user doesn't return tensor as-is
            lib.define("bar(Tensor(a!) self, Tensor other) -> Tensor(a!)")

            def bar_impl(x, y):
                x_d = x.detach()
                y = y.detach()
                x_d.add_(y)
                return x_d.clone()

            lib.impl("bar", bar_impl, "CPU")
            bar = self.get_op("bar")

            # User mutated input tensor but didn't return it.
            lib.define("baz(Tensor(a!) self, Tensor other) -> ()")

            def baz_impl(x, y):
                x_d = x.detach()
                y = y.detach()
                x_d.add_(y)

            lib.impl("baz", baz_impl, "CPU")
            baz = self.get_op("baz")

            # Test in-place on non-view
            for op in (foo, bar, baz):
                x = torch.randn(3)
                y = torch.randn(3, requires_grad=True)
                with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                    z = x.clone()
                    op(z, y)
                    torch.autograd.grad(z, y, torch.ones_like(z), allow_unused=True)

            # Test in-place on view
            for op in (foo, bar, baz):
                x = torch.randn(3)
                y = torch.randn(3, requires_grad=True)
                with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                    z = x[:]
                    op(z, y)
                    torch.autograd.grad(z, x, torch.ones_like(z), allow_unused=True)

    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_leaf(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a) -> (Tensor, Tensor)")
            op = self.get_op("foo")

            lib.impl("foo", lambda a: (a.clone(), a.clone().detach().requires_grad_()), "CPU")
            x = torch.randn(3, requires_grad=True)
            y, z = op(x)
            with self._check_ctx(mode):
                z.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_mix_of_requires_grad_tensors(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor, Tensor)")
            op = self.get_op("foo")

            def foo_impl(a, b):
                with torch.no_grad():
                    x = a.clone()
                    z = b.clone()
                y = a * b
                return x, y, z

            lib.impl("foo", foo_impl, "CPU")
            a = torch.randn(3, requires_grad=True)
            b = torch.randn(3, requires_grad=True)
            x, y, z = op(a, b)

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(x, (a, b), torch.ones_like(x), allow_unused=True, retain_graph=True)

            with self._check_ctx(mode, mode_nothing_raises=False):
                torch.autograd.grad(y, (a, b), torch.ones_like(y), allow_unused=True, retain_graph=True)

            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(z, (a, b), torch.ones_like(z), allow_unused=True, retain_graph=True)

    @parametrize("mode", ("nothing", "warn"))
    def test_supports_tensor_lists(self, mode):
        with autograd_fallback_mode(mode):
            lib = self.get_lib()
            lib.define("foo(Tensor[] a) -> Tensor[]")
            op = self.get_op("foo")

            def foo_impl(a):
                x, y, z = a
                with torch.no_grad():
                    return x + y + z, x * y * z

            lib.impl("foo", foo_impl, "CPU")
            x = torch.randn(3, requires_grad=True)
            y = torch.randn(1, requires_grad=True)
            z = torch.randn(2, 1, requires_grad=True)
            a, b = op([x, y, z])
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(a, (x, y, z), torch.ones_like(a), allow_unused=True, retain_graph=True)
            with self._check_ctx(mode, mode_nothing_raises=True):
                torch.autograd.grad(b, (x, y, z), torch.ones_like(b), allow_unused=True, retain_graph=True)


instantiate_parametrized_tests(TestAutogradFallback)

if __name__ == '__main__':
    run_tests()
@@ -11267,6 +11267,7 @@ class TestAutogradMultipleDispatch(TestCase):
 
 from autograd.test_complex import TestAutogradComplex  # noqa: F401
 from autograd.test_functional import TestAutogradFunctional  # noqa: F401
+from autograd.test_fallback import TestAutogradFallback  # noqa: F401
 
 # e.g., TestAutogradDeviceTypeCPU and TestAutogradDeviceTypeCUDA
 instantiate_device_type_tests(
@@ -391,7 +391,7 @@ CPU: impl_t_t [kernel]
 CUDA: default_def_name_t_t [math kernel]
 XLA: default_def_name_t_t [math kernel]
 AutogradOther: default_def_name_t_t [math kernel]
-AutogradCPU: fallthrough registered in pytorch framework [backend fallback]
+AutogradCPU: registered in pytorch framework [backend fallback]
 AutogradCUDA: default_def_name_t_t [math kernel]
 AutogradXLA: default_def_name_t_t [math kernel]
 ''')
@@ -456,7 +456,7 @@ CPU: fn_cpu [kernel]
 CUDA: fn_math [math kernel]
 XLA: fn_math [math kernel]
 AutogradOther: fn_math [math kernel]
-AutogradCPU: fallthrough registered in pytorch framework [backend fallback]
+AutogradCPU: registered in pytorch framework [backend fallback]
 AutogradCUDA: fn_math [math kernel]
 AutogradXLA: fn_math [math kernel]
 ''')
@@ -587,10 +587,10 @@ Undefined: fn_defaultbackend [default backend kernel]
 CPU: fn_cpu [kernel]
 CUDA: fn_defaultbackend [default backend kernel]
 XLA: fn_defaultbackend [default backend kernel]
-AutogradOther: fallthrough registered in pytorch framework [backend fallback]
-AutogradCPU: fallthrough registered in pytorch framework [backend fallback]
-AutogradCUDA: fallthrough registered in pytorch framework [backend fallback]
-AutogradXLA: fallthrough registered in pytorch framework [backend fallback]
+AutogradOther: registered in pytorch framework [backend fallback]
+AutogradCPU: registered in pytorch framework [backend fallback]
+AutogradCUDA: registered in pytorch framework [backend fallback]
+AutogradXLA: registered in pytorch framework [backend fallback]
 ''')
 
     def test_computed_table_with_cpu_autograd_defaultbackend(self):
@@ -814,9 +814,9 @@ XLA fn_XLA [kernel]
 Lazy fn_Lazy [kernel]
 FPGA fn_CompositeImplicitAutograd [math kernel]
 AutogradOther fn_CompositeImplicitAutograd [math kernel]
-AutogradCPU fallthrough [backend fallback]
-AutogradXLA fallthrough [backend fallback]
-AutogradLazy fallthrough [backend fallback]
+AutogradCPU [backend fallback]
+AutogradXLA [backend fallback]
+AutogradLazy [backend fallback]
 '''
 )
 
@@ -836,8 +836,8 @@ Lazy fn_Lazy [kernel]
 FPGA fn_CompositeImplicitAutograd [math kernel]
 AutogradOther fn_CompositeImplicitAutograd [math kernel]
 AutogradCPU fn_AutogradCPU [kernel]
-AutogradXLA fallthrough [backend fallback]
-AutogradLazy fallthrough [backend fallback]
+AutogradXLA [backend fallback]
+AutogradLazy [backend fallback]
 '''
 )
         self.assertExpectedInline(
@@ -869,10 +869,10 @@ CPU fn_CPU [kernel]
 XLA fn_XLA [kernel]
 Lazy fn_Lazy [kernel]
 FPGA fn_CompositeExplicitAutograd [default backend kernel]
-AutogradOther fallthrough [backend fallback]
+AutogradOther [backend fallback]
 AutogradCPU fn_AutogradCPU [kernel]
-AutogradXLA fallthrough [backend fallback]
-AutogradLazy fallthrough [backend fallback]
+AutogradXLA [backend fallback]
+AutogradLazy [backend fallback]
 '''
 )
 
@@ -906,7 +906,7 @@ XLA fn_CompositeImplicitAutograd [math kernel]
 Lazy fn_CompositeImplicitAutograd [math kernel]
 FPGA fn_FPGA [kernel]
 AutogradOther ambiguous_autogradother [ambiguous autogradother]
-AutogradCPU fallthrough [backend fallback]
+AutogradCPU [backend fallback]
 AutogradXLA fn_CompositeImplicitAutograd [math kernel]
 AutogradLazy fn_CompositeImplicitAutograd [math kernel]
 '''
@@ -1230,6 +1230,9 @@ class _InferenceMode:
     def __enter__(self): ...
     def __exit__(self, exc_type, exc_value, traceback): ...
 
+def _set_autograd_fallback_mode(mode: str) -> None: ...
+def _get_autograd_fallback_mode() -> str: ...
+
 # Defined in torch/csrc/jit/python/script_init.cpp
 class LoggerBase: ...
 class NoopLogger(LoggerBase): ...
@@ -45,8 +45,191 @@ void _foreach_tensor(
   }
 }
 
+AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn;
+
 } // namespace
 
+void setAutogradFallbackMode(AutogradFallbackMode mode) {
+  TORCH_CHECK(mode != AutogradFallbackMode::Error, "NYI: mode='error'");
+  kAutogradFallbackMode = mode;
+}
+
+AutogradFallbackMode getAutogradFallbackMode() {
+  return kAutogradFallbackMode;
+}
+
+static void warnAutogradNotImplemented(const std::string& op_name) {
+  TORCH_WARN(
+      op_name,
+      ": an autograd kernel was not registered to the Autograd key(s) ",
+      "but we are trying to backprop through it. This may lead to silently incorrect behavior. ",
+      "This behavior is deprecated and will be removed in a future version of PyTorch. ",
+      "If your operator is differentiable, please ensure you have registered an "
+      "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
+      "DispatchKey::CompositeImplicitAutograd). If your operator is not "
+      "differentiable, or to squash this warning and use the previous behavior, "
+      "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
+}
+
+struct WarnNotImplemented : public Node {
+  WarnNotImplemented(
+      std::string op_name,
+      int64_t num_outputs,
+      edge_list&& next_edges)
+      : Node(std::move(next_edges)),
+        op_name(std::move(op_name)),
+        num_outputs(num_outputs) {}
+
+  WarnNotImplemented(std::string op_name, int64_t num_outputs)
+      : op_name(std::move(op_name)), num_outputs(num_outputs) {}
+
+  variable_list apply(variable_list&& inputs) override;
+
+  std::string op_name;
+  int64_t num_outputs;
+};
+
+auto WarnNotImplemented::apply(variable_list&& inputs) -> variable_list {
+  warnAutogradNotImplemented(op_name);
+  std::vector<at::Tensor> output(num_outputs);
+  return output;
+}
+
+static void basicAutogradNotImplementedFallbackImpl(
+    const c10::OperatorHandle& op,
+    c10::DispatchKeySet dispatch_keys,
+    torch::jit::Stack* stack) {
+  const auto& schema = op.schema();
+  const auto& op_name = schema.operator_name().name;
+  const auto num_arguments = schema.arguments().size();
+  const auto num_returns = schema.returns().size();
+  const auto stack_start = stack->size() - num_arguments;
+  const bool grad_mode = GradMode::is_enabled();
+
+  if (getAutogradFallbackMode() == AutogradFallbackMode::Nothing) {
+    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
+    return;
+  }
+  TORCH_INTERNAL_ASSERT(
+      getAutogradFallbackMode() == AutogradFallbackMode::Warn);
+
+  bool any_input_requires_grad = false;
+  if (grad_mode) {
+    _foreach_tensor(
+        [&](size_t _, size_t idx_arg, const at::Tensor& t) {
+          if (t.requires_grad()) {
+            any_input_requires_grad = true;
+          }
+        },
+        stack,
+        stack_start,
+        num_arguments);
+  }
+
+  std::shared_ptr<WarnNotImplemented> grad_fn;
+  if (any_input_requires_grad) {
+    // NB: It is standard to collect edges from all tensors
+    // (see generated/VariableTypeEverything.cpp for examples)
+    std::vector<const at::Tensor*> all_tensors_on_stack;
+    _foreach_tensor(
+        [&](size_t _, size_t idx_arg, const at::Tensor& t) {
+          all_tensors_on_stack.push_back(&t);
+        },
+        stack,
+        stack_start,
+        num_arguments);
+    grad_fn = std::shared_ptr<WarnNotImplemented>(
+        new WarnNotImplemented(op_name, all_tensors_on_stack.size()),
+        deleteNode);
+    grad_fn->set_next_edges(collect_next_edges(all_tensors_on_stack));
+  }
+
+  op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
+
+  if (any_input_requires_grad) {
+    // NB: if the operator mutates any inputs in-place and does not return them
+    // as outputs, we are unable to lazily raise a warning. This is OK because
+    // we don't expect many existing operators to do this because of the amount
+    // of technical expertise necessary (you would need to manually register an
+    // autograd kernel without using autograd.Function)
+    _foreach_tensor(
+        [&](size_t _, size_t idx_ret, const at::Tensor& t) {
+          if (!isDifferentiableType(t.scalar_type())) {
+            return;
+          }
+          const bool is_mutable_output =
+              schema.is_aliasing({c10::SchemaArgType::output, idx_ret}) &&
+              schema.is_mutable({c10::SchemaArgType::output, idx_ret});
+
+          // If the post-autograd implementation returns Tensors that require
+          // grad, then we install a hook that will warn during the backwards.
+          //
+          // NB: If the operation is inplace and the inputs were views,
+          // it is possible that the history was rebased and the hook will
+          // not warn in all places where it should. That is, the following
+          // won't warn:
+          // >>> x = torch.randn(3, 3, requires_grad=True)
+          // >>> z = x.clone()
+          // >>> w = z[0]
+          // >>> k = w[0]
+          // >>> y = op(k)
+          // >>> torch.autograd.grad(z.sum(), w)
+          if (t.requires_grad()) {
+            t.register_hook([op_name](const at::Tensor& grad) {
+              warnAutogradNotImplemented(op_name);
+              return grad;
+            });
+            // If history is rebased, then we will attempt to warn
+            // on the view's base. This will catch most cases (because
+            // users typically call .backward() and backprop through
+            // the entire program).
+            if (t.is_view() && is_mutable_output) {
+              // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+              const_cast<at::TensorBase&>(t._base()).register_hook(
+                  [op_name](const at::TensorBase& grad) {
+                    warnAutogradNotImplemented(op_name);
+                    return grad;
+                  });
+            }
+            return;
+          }
+
+          // If the post-autograd implementation returns any Tensors that
+          // don't require grad, then we install the WarnNotImplemented grad_fn.
+          // This grad_fn warns in backward and returns undefined tensor
+          // gradients.
+          //
+          // NOTE [autograd fallback and in-place operations]
+          // If the schema says the output is mutable, and the output
+          // is an input, and the input is a view Tensor, then...
+          // we're not sure if set_history is OK to do, so we just skip
+          // adding the grad_fn. Builtin operators do rebase_history here,
+          // but custom operators may have multiple Tensor(a!) returns,
+          // rebase_history assumes single Tensor(a!) return, and in general
+          // custom ops don't have a good in-place story.
+          if (!is_mutable_output) {
+            // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+            set_history(const_cast<at::Tensor&>(t), grad_fn);
+          }
+        },
+        stack,
+        stack->size() - num_returns,
+        num_returns);
+  }
+}
+
+torch::CppFunction basicAutogradNotImplementedFallback() {
+  return torch::CppFunction::makeFromBoxedFunction<
+      &basicAutogradNotImplementedFallbackImpl>();
+}
+
+void VariableHooks::basic_autograd_not_implemented_fallback(
+    const c10::OperatorHandle& op,
+    c10::DispatchKeySet dispatch_keys,
+    torch::jit::Stack* stack) const {
+  basicAutogradNotImplementedFallbackImpl(op, dispatch_keys, stack);
+}
+
 static void autogradNotImplementedFallbackImpl(
     const c10::OperatorHandle& op,
     c10::DispatchKeySet dispatch_keys,
@@ -5,9 +5,30 @@
 namespace torch {
 namespace autograd {
 
+// Default DispatchKey::Autograd fallback for built-in operators.
+// Can be registered for custom operators.
 TORCH_API torch::CppFunction autogradNotImplementedFallback();
 
+// Default DispatchKey::AdInplaceOrView fallback for built-in operators
+// Can be registered for custom operators.
 TORCH_API torch::CppFunction autogradNotImplementedInplaceOrViewFallback();
 
+// Default DispatchKey::Autograd fallback for all other operators (i.e. custom
+// operators)
+TORCH_API torch::CppFunction basicAutogradNotImplementedFallback();
+
+enum class AutogradFallbackMode {
+  Nothing, // Fallback is a redispatch
+  Warn, // Fallback raises a warning if backward is called
+  Error, // Fallback raises an error if backward is called
+};
+
+// Change the behavior of "basicAutogradNotImplementedFallback"
+// In Python this is:
+// - torch._C._set_autograd_fallback_mode(str) -> None
+// - torch._C._get_autograd_fallback_mode() -> str
+TORCH_API void setAutogradFallbackMode(AutogradFallbackMode mode);
+TORCH_API AutogradFallbackMode getAutogradFallbackMode();
+
 } // namespace autograd
 } // namespace torch
@@ -12,6 +12,7 @@
 #include <torch/csrc/Exceptions.h>
 #include <torch/csrc/autograd/VariableTypeUtils.h>
 #include <torch/csrc/autograd/autograd.h>
+#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/grad_mode.h>
 #include <torch/csrc/autograd/profiler.h>
@@ -408,6 +409,37 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
     auto cls = python_type_class.ptr();
     registerPythonTensorClass(device, cls);
   });
+  _C_m.def("_set_autograd_fallback_mode", [](const std::string& mode) {
+    if (mode == "nothing") {
+      torch::autograd::setAutogradFallbackMode(
+          torch::autograd::AutogradFallbackMode::Nothing);
+      return;
+    }
+    if (mode == "warn") {
+      torch::autograd::setAutogradFallbackMode(
+          torch::autograd::AutogradFallbackMode::Warn);
+      return;
+    }
+    if (mode == "error") {
+      torch::autograd::setAutogradFallbackMode(
+          torch::autograd::AutogradFallbackMode::Error);
+      return;
+    }
+    TORCH_INTERNAL_ASSERT(false, "Unsupported AutogradFallbackMode: ", mode);
+  });
+  _C_m.def("_get_autograd_fallback_mode", []() {
+    auto mode = torch::autograd::getAutogradFallbackMode();
+    switch (mode) {
+      case torch::autograd::AutogradFallbackMode::Nothing:
+        return "nothing";
+      case torch::autograd::AutogradFallbackMode::Warn:
+        return "warn";
+      case torch::autograd::AutogradFallbackMode::Error:
+        return "error";
+      default:
+        TORCH_INTERNAL_ASSERT(false, "Unsupported AutogradFallbackMode");
+    }
+  });
 
   _C_m.def("_activate_cuda_trace", []() { activateCUDATrace(); });
 
@@ -10,8 +10,6 @@
 #include <torch/csrc/autograd/generated/Functions.h>
 #include <torch/csrc/autograd/utils/error_messages.h>
 
-#include <ATen/core/VariableHooksInterface.h>
-
 #include <ATen/ATen.h>
 #include <ATen/FuncTorchTLS.h>
 #include <ATen/MemoryOverlap.h>
@@ -394,36 +392,6 @@ DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase& self) {
 
 using at::Tensor;
 
-struct VariableHooks final : at::impl::VariableHooksInterface {
-  at::TensorBase tensor_data(const at::TensorBase&) const override;
-  at::TensorBase variable_data(const at::TensorBase&) const override;
-  const std::shared_ptr<torch::autograd::Node>& grad_fn(
-      const at::TensorBase&) const override;
-  unsigned _register_hook(
-      const at::TensorBase&,
-      std::function<at::TensorBase(const at::TensorBase&)> hook) const override;
-  void remove_hook(const at::TensorBase&, unsigned pos) const override;
-  bool is_view(const at::TensorBase&) const override;
-  const at::TensorBase& base(const at::TensorBase&) const override;
-  const std::string& name(const at::TensorBase&) const override;
-  bool is_leaf(const at::TensorBase&) const override;
-  int64_t output_nr(const at::TensorBase&) const override;
-  void set_data(const at::TensorBase& self, const at::TensorBase& new_data)
-      const override;
-  at::TensorBase data(const at::TensorBase& self) const override;
-  int64_t _version(const at::TensorBase& self) const override;
-  void retain_grad(const at::TensorBase& self) const override;
-  bool retains_grad(const at::TensorBase& self) const override;
-  void _backward(
-      const Tensor& self,
-      at::TensorList inputs,
-      const c10::optional<Tensor>& gradient,
-      c10::optional<bool> keep_graph,
-      bool create_graph) const override;
-  void requires_grad_(const at::TensorBase& self, bool _requires_grad)
-      const override;
-};
-
 VariableHooks variableHooks;
 at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks);
 
@@ -10,6 +10,7 @@
 
 #include <ATen/NamedTensorUtils.h>
 #include <ATen/core/Tensor.h>
+#include <ATen/core/VariableHooksInterface.h>
 #include <c10/util/Exception.h>
 
 #include <cstdint>
@@ -796,6 +797,40 @@ inline Variable make_variable(
   return Variable();
 }
 
+struct VariableHooks final : at::impl::VariableHooksInterface {
+  at::TensorBase tensor_data(const at::TensorBase&) const override;
+  at::TensorBase variable_data(const at::TensorBase&) const override;
+  const std::shared_ptr<torch::autograd::Node>& grad_fn(
+      const at::TensorBase&) const override;
+  unsigned _register_hook(
+      const at::TensorBase&,
+      std::function<at::TensorBase(const at::TensorBase&)> hook) const override;
+  void remove_hook(const at::TensorBase&, unsigned pos) const override;
+  bool is_view(const at::TensorBase&) const override;
+  const at::TensorBase& base(const at::TensorBase&) const override;
+  const std::string& name(const at::TensorBase&) const override;
+  bool is_leaf(const at::TensorBase&) const override;
+  int64_t output_nr(const at::TensorBase&) const override;
+  void set_data(const at::TensorBase& self, const at::TensorBase& new_data)
+      const override;
+  at::TensorBase data(const at::TensorBase& self) const override;
+  int64_t _version(const at::TensorBase& self) const override;
+  void retain_grad(const at::TensorBase& self) const override;
+  bool retains_grad(const at::TensorBase& self) const override;
+  void _backward(
+      const at::Tensor& self,
+      at::TensorList inputs,
+      const c10::optional<at::Tensor>& gradient,
+      c10::optional<bool> keep_graph,
+      bool create_graph) const override;
+  void requires_grad_(const at::TensorBase& self, bool _requires_grad)
+      const override;
+  void basic_autograd_not_implemented_fallback(
+      const c10::OperatorHandle& op,
+      c10::DispatchKeySet dispatch_keys,
+      torch::jit::Stack* stack) const override;
+};
+
 namespace utils {
 
 TORCH_API bool has_same_meta(const Variable& base, const Variable& other);