mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Make PT2 compile backprop through custom op without autograd key a hard error (#166367)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166367
Approved by: https://github.com/bdhirsh
This commit is contained in:
parent 8f40a0c634
commit 4acc66f119
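For context, here is a minimal sketch of the behavior this commit enforces. It is not taken from the PR: the op name "mylib::no_autograd_op" is made up, and exactly where the error surfaces (at compile time or at .backward()) depends on when AOT autograd traces the backward. The point is that backprop through an op with no autograd registration used to silently detach (or warn) under torch.compile and is now a hard error.

import torch
from torch.library import _scoped_library

with _scoped_library("mylib", "FRAGMENT") as lib:
    # An op with a forward kernel but no Autograd registration.
    lib.define("no_autograd_op(Tensor x) -> Tensor")
    lib.impl("no_autograd_op", lambda x: x.clone(), "CompositeExplicitAutograd")

    @torch.compile(backend="aot_eager", fullgraph=True)
    def f(x):
        return torch.ops.mylib.no_autograd_op(x)

    x = torch.randn(3, requires_grad=True)
    try:
        # With autograd_fallback_mode("error") active during AOT tracing,
        # backpropping through the op is now a hard error instead of a
        # silent detach.
        f(x).sum().backward()
    except RuntimeError as e:
        print("hard error, as introduced by this PR:", e)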
@@ -414,6 +414,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model = Model().to(self.device)
            model.emb.weight.requires_grad = False
            model_compiled = torch.compile(model)
            inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device=self.device)
            out = model_compiled(inp, self.world_size, **self.get_world_trs())
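Freezing model.emb.weight above is what keeps this test working under the new hard error: if nothing feeding the functional collective requires grad, autograd never reaches the not-implemented fallback. A standalone, hedged sketch of the same pattern without the distributed harness follows; TinyModel and the shapes are made up.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Hypothetical stand-in for the distributed Model used in the test above.
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(8, 4)

    def forward(self, x):
        return self.emb(x).sum()

model = TinyModel()
# Freeze the parameter so no gradient ever needs to flow through ops that
# lack an autograd formula (e.g. functional collectives in the real test).
model.emb.weight.requires_grad = False
compiled = torch.compile(model)
out = compiled(torch.tensor([[2, 1, 3, 0]], dtype=torch.long))
print(out.requires_grad)  # False: nothing upstream requires grad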
@@ -1340,13 +1341,11 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
        assert counter.op_count == 3  # It generates 2 getattr to unpack the array
        assert same(out, correct)

    # This doesn't work in all cases, and now we properly loudly error.
    # See: https://github.com/pytorch/pytorch/issues/151240
    # When differentiable funcols are implemented can revert.
    @unittest.expectedFailure
    def test_backwards(self):
        """
        It's probably not that common to need backwards support for collectives.

        However, I wanted to at least see if it was possible to support it as a design goal.
        """

        def func(inp):
            ar = _functional_collectives.all_reduce(inp, "sum", "0")
            return ar
@@ -9757,6 +9757,17 @@ def ___make_guard_fn():
        def foo_impl(x, y):
            return torch.cat([x, y])

        def setup_context(ctx, inputs, output):
            (x, _) = inputs
            ctx.xs = x.shape[0]

        def foo_backward(ctx, grad):
            return grad[: ctx.xs], grad[ctx.xs :]

        torch.library.register_autograd(
            "mylib::foo", foo_backward, setup_context=setup_context
        )

        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, i):
            i0, i1 = i.tolist()
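The hunk above only shows the autograd pieces the test now registers; the definition of mylib::foo itself sits outside this excerpt. A hedged sketch of the full pattern follows: the library setup and schema are assumptions, while foo_impl, setup_context, foo_backward, and the register_autograd call mirror the hunk.

import torch

lib = torch.library.Library("mylib", "FRAGMENT")  # assumed; the real test scopes its library
lib.define("foo(Tensor x, Tensor y) -> Tensor")


def foo_impl(x, y):
    return torch.cat([x, y])


lib.impl("foo", foo_impl, "CompositeExplicitAutograd")


def setup_context(ctx, inputs, output):
    # Remember how many rows came from x so the gradient can be split.
    (x, _) = inputs
    ctx.xs = x.shape[0]


def foo_backward(ctx, grad):
    return grad[: ctx.xs], grad[ctx.xs :]


torch.library.register_autograd(
    "mylib::foo", foo_backward, setup_context=setup_context
)

x = torch.randn(2, requires_grad=True)
y = torch.randn(3, requires_grad=True)
torch.ops.mylib.foo(x, y).sum().backward()
print(x.grad.shape, y.grad.shape)  # torch.Size([2]) torch.Size([3])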
@@ -1254,6 +1254,8 @@ def forward(self, x_1: "f32[2][1]cpu"):
        torch._dynamo.reset()

        mod = SimpleModule().cuda()
        for p in mod.parameters():
            p.requires_grad = False
        compiled = torch.compile(mod, backend="inductor")
        compiled(torch.randn(4, 4, device="cuda"))
@@ -1321,6 +1323,8 @@ def forward(self, x_1: "f32[2][1]cpu"):
        torch._dynamo.reset()

        mod = MixedModule().cuda()
        for p in mod.parameters():
            p.requires_grad = False
        compiled = torch.compile(mod, backend="inductor")
        compiled(torch.randn(4, 4, device="cuda"))
@@ -1375,6 +1379,8 @@ def forward(self, x_1: "f32[2][1]cpu"):
        with self._setup_runtime_estimates_capture() as payload_buffer:
            torch._dynamo.reset()
            mod = Mixed().cuda()
            for p in mod.parameters():
                p.requires_grad = False
            compiled = torch.compile(mod, backend="inductor")
            compiled(torch.randn(4, 4, device="cuda"))
        payload = payload_buffer.getvalue().strip()
@@ -6,6 +6,7 @@ import warnings
import numpy as np

import torch
from torch._library.autograd import autograd_fallback_mode
from torch.library import _scoped_library
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
@@ -15,16 +16,6 @@ from torch.testing._internal.common_utils import (
)


@contextlib.contextmanager
def autograd_fallback_mode(mode):
    prev = torch._C._get_autograd_fallback_mode()
    try:
        torch._C._set_autograd_fallback_mode(mode)
        yield
    finally:
        torch._C._set_autograd_fallback_mode(prev)


class TestAutogradFallback(TestCase):
    test_ns = "_test_autograd_fallback"
@@ -26,6 +26,7 @@ from torch._dynamo.utils import (
from torch._guards import detect_fake_mode
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
from torch._inductor.utils import BoxedBool
from torch._library.autograd import autograd_fallback_mode
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.export._tree_utils import reorder_kwargs
from torch.fx.experimental.proxy_tensor import make_fx
@@ -528,6 +529,9 @@ def create_aot_state(
        stack.enter_context(
            torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing()
        )
        # Make it an error to backprop through PT2 compliant ops that silently
        # detach autograd
        stack.enter_context(autograd_fallback_mode("error"))

        from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
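The enter_context call above scopes the "error" fallback mode to AOT tracing only; outside of compilation the default "warn" mode still applies. A hedged sketch of that contrast follows: the op name is made up, the eager behavior (warning plus undefined gradients) follows the C++ fallback shown further down, and the exact point where the compiled path raises may be at compile time rather than at .backward().

import torch
from torch.library import _scoped_library

with _scoped_library("mylib", "FRAGMENT") as lib:
    lib.define("warn_only(Tensor x) -> Tensor")
    lib.impl("warn_only", lambda x: x.clone(), "CompositeExplicitAutograd")

    x = torch.randn(3, requires_grad=True)

    # Eager: the default "warn" fallback applies; backward warns that no
    # autograd kernel is registered and produces undefined (None) gradients.
    torch.ops.mylib.warn_only(x).sum().backward()
    print(x.grad)  # expected: None

    # Compiled: create_aot_state enters autograd_fallback_mode("error"), so
    # the same backprop attempt becomes a hard error during tracing.
    compiled = torch.compile(
        lambda t: torch.ops.mylib.warn_only(t), backend="aot_eager"
    )
    try:
        compiled(x).sum().backward()
    except RuntimeError as e:
        print("hard error under compile:", e)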
@@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
import contextlib
import dataclasses
from collections.abc import Callable
from dataclasses import dataclass
@@ -235,6 +236,16 @@ def not_list_of_optional_tensor(tree):
    return True


@contextlib.contextmanager
def autograd_fallback_mode(mode):
    prev = _C._get_autograd_fallback_mode()
    try:
        _C._set_autograd_fallback_mode(mode)
        yield
    finally:
        _C._set_autograd_fallback_mode(prev)


flatten = _pytree.tree_flatten
unflatten = _pytree.tree_unflatten
spec_t = _pytree.TreeSpec
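This is the context manager that create_aot_state enters in the hunk above. A minimal usage sketch, assuming the mode strings mirror the C++ fallback modes that appear in this diff ("warn" by default, plus the new "error"):

import torch
from torch._library.autograd import autograd_fallback_mode

# Temporarily make missing-autograd-kernel backprop a hard error instead of a
# warning; the previous mode is restored on exit, even if an exception escapes.
with autograd_fallback_mode("error"):
    ...  # run tracing / tests that must not silently detach gradients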
@@ -50,7 +50,6 @@ AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn;
} // namespace

void setAutogradFallbackMode(AutogradFallbackMode mode) {
  TORCH_CHECK(mode != AutogradFallbackMode::Error, "NYI: mode='error'");
  kAutogradFallbackMode = mode;
}
@@ -58,7 +57,10 @@ AutogradFallbackMode getAutogradFallbackMode() {
  return kAutogradFallbackMode;
}

static void warnAutogradNotImplemented(const std::string& op_name) {
static void reportAutogradNotImplemented(
    const std::string& op_name,
    bool is_warn) {
  if (is_warn) {
    TORCH_WARN(
        op_name,
        ": an autograd kernel was not registered to the Autograd key(s) ",
@@ -69,30 +71,47 @@ static void warnAutogradNotImplemented(const std::string& op_name) {
        "DispatchKey::CompositeImplicitAutograd). If your operator is not "
        "differentiable, or to squash this warning and use the previous behavior, "
        "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
  } else {
    TORCH_CHECK(
        0,
        op_name,
        ": an autograd kernel was not registered to the Autograd key(s) ",
        "but we are trying to backprop through it. This can lead to silently incorrect behavior. ",
        "If your operator is differentiable, please ensure you have registered an "
        "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
        "). If your operator is not "
        "differentiable and ensure NO gradients flow through this operator, "
        "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
  }
}

struct WarnNotImplemented : public Node {
  WarnNotImplemented(
struct NotImplementedBackward : public Node {
  NotImplementedBackward(
      std::string op_name,
      size_t num_outputs,
      bool is_warn,
      edge_list&& next_edges)
      : Node(std::move(next_edges)),
        op_name(std::move(op_name)),
        num_outputs(num_outputs) {}
        num_outputs(num_outputs),
        is_warn(is_warn) {}

  WarnNotImplemented(std::string op_name, size_t num_outputs)
      : op_name(std::move(op_name)), num_outputs(num_outputs) {}
  NotImplementedBackward(std::string op_name, size_t num_outputs, bool is_warn)
      : op_name(std::move(op_name)),
        num_outputs(num_outputs),
        is_warn(is_warn) {}

  variable_list apply(variable_list&& inputs) override;

  std::string op_name;
  size_t num_outputs;
  bool is_warn;
};

// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
auto WarnNotImplemented::apply(variable_list&& inputs) -> variable_list {
auto NotImplementedBackward::apply(variable_list&& inputs) -> variable_list {
  auto inputsLocal = std::move(inputs);
  warnAutogradNotImplemented(op_name);
  reportAutogradNotImplemented(op_name, is_warn);
  std::vector<at::Tensor> output(num_outputs);
  return output;
}
@@ -111,8 +130,6 @@ static void basicAutogradNotImplementedFallbackImpl(
    op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
    return;
  }
  TORCH_INTERNAL_ASSERT(
      getAutogradFallbackMode() == AutogradFallbackMode::Warn);

  bool any_input_requires_grad = false;
  _foreach_tensor(
@@ -128,7 +145,9 @@ static void basicAutogradNotImplementedFallbackImpl(
  // by putting it after the requires_grad checks.
  any_input_requires_grad = any_input_requires_grad && GradMode::is_enabled();

  std::shared_ptr<WarnNotImplemented> grad_fn;
  bool is_warn = getAutogradFallbackMode() == AutogradFallbackMode::Warn;

  std::shared_ptr<NotImplementedBackward> grad_fn;
  if (any_input_requires_grad) {
    // NB: It is standard to collect edges from all tensors
    // (see generated/VariableTypeEverything.cpp for examples)
@@ -140,8 +159,9 @@ static void basicAutogradNotImplementedFallbackImpl(
        stack,
        stack_start,
        num_arguments);
    grad_fn = std::shared_ptr<WarnNotImplemented>(
        new WarnNotImplemented(op_name, all_tensors_on_stack.size()),
    grad_fn = std::shared_ptr<NotImplementedBackward>(
        new NotImplementedBackward(
            op_name, all_tensors_on_stack.size(), is_warn),
        deleteNode);
    grad_fn->set_next_edges(collect_next_edges(all_tensors_on_stack));
  }
@@ -177,8 +197,8 @@ static void basicAutogradNotImplementedFallbackImpl(
        // >>> y = op(k)
        // >>> torch.autograd.grad(z.sum(), w)
        if (t.requires_grad()) {
          t.register_hook([op_name](const at::Tensor& grad) {
            warnAutogradNotImplemented(op_name);
          t.register_hook([op_name, is_warn](const at::Tensor& grad) {
            reportAutogradNotImplemented(op_name, is_warn);
          });
          // If history is rebased, then we will attempt to warn
          // on the view's base. This will catch most cases (because
@@ -188,8 +208,9 @@ static void basicAutogradNotImplementedFallbackImpl(
          const auto& base = t._base();
          if (base.requires_grad()) {
            // Can only register_hook on tensors that require grad.
            base.register_hook([op_name](const at::TensorBase& grad) {
              warnAutogradNotImplemented(op_name);
            base.register_hook(
                [op_name, is_warn](const at::TensorBase& grad) {
                  reportAutogradNotImplemented(op_name, is_warn);
                });
          }
        }
@@ -197,9 +218,9 @@ static void basicAutogradNotImplementedFallbackImpl(
      }

  // If the post-autograd implementation returns any Tensors that
  // don't require grad, then we install the WarnNotImplemented grad_fn.
  // This grad_fn warns in backward and returns undefined tensor
  // gradients.
  // don't require grad, then we install the NotImplementedBackward
  // grad_fn. This grad_fn warns in backward and returns undefined
  // tensor gradients.
  //
  // NOTE [autograd fallback and in-place operations]
  // If the schema says the output is mutable, and the output