Revert "Make PT2 compile backprop through custom op without autograd key a hard error (#166367)"

This reverts commit 4acc66f119.

Reverted https://github.com/pytorch/pytorch/pull/166367 on behalf of https://github.com/atalman due to internal build failures ([comment](https://github.com/pytorch/pytorch/pull/166367#issuecomment-3473150269))
PyTorch MergeBot 2025-10-31 13:44:05 +00:00
parent 4e8ba37ce3
commit 5bcfdae71d
7 changed files with 48 additions and 91 deletions
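For orientation, a minimal sketch of the behavior this revert restores (op name and shapes are illustrative, not taken from the PR): backprop through an op that lacks an Autograd registration goes back to the long-standing warn-and-return-undefined-gradients fallback instead of a hard error during PT2 tracing.

import torch
from torch.library import _scoped_library

# Hypothetical op with only a CompositeExplicitAutograd kernel and no autograd
# registration; with the revert applied, the autograd fallback stays in "warn" mode.
with _scoped_library("mylib", "FRAGMENT") as lib:
    lib.define("scale(Tensor x) -> Tensor")
    lib.impl("scale", lambda x: x * 2.0, "CompositeExplicitAutograd")

    x = torch.randn(3, requires_grad=True)
    y = torch.ops.mylib.scale(x)
    # Warns about the missing Autograd kernel and leaves x.grad as None
    # (undefined gradient) rather than raising.
    y.sum().backward()
    print(x.grad)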

View File

@@ -414,7 +414,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
model = Model().to(self.device)
model.emb.weight.requires_grad = False
model_compiled = torch.compile(model)
inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device=self.device)
out = model_compiled(inp, self.world_size, **self.get_world_trs())
@@ -1341,11 +1340,13 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
assert counter.op_count == 3 # It generates 2 getattr to unpack the array
assert same(out, correct)
# This doesn't work in all cases, and now we properly loudly error.
# See: https://github.com/pytorch/pytorch/issues/151240
# When differentiable funcols are implemented can revert.
@unittest.expectedFailure
def test_backwards(self):
"""
It's probably not that common to need backwards support for collectives.
However, I wanted to at least see if it was possible to support it as a design goal.
"""
def func(inp):
ar = _functional_collectives.all_reduce(inp, "sum", "0")
return ar

View File

@@ -9784,17 +9784,6 @@ def ___make_guard_fn():
def foo_impl(x, y):
return torch.cat([x, y])
def setup_context(ctx, inputs, output):
(x, _) = inputs
ctx.xs = x.shape[0]
def foo_backward(ctx, grad):
return grad[: ctx.xs], grad[ctx.xs :]
torch.library.register_autograd(
"mylib::foo", foo_backward, setup_context=setup_context
)
@torch.compile(backend="aot_eager", fullgraph=True)
def f(x, i):
i0, i1 = i.tolist()
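The lines removed from this test registered a backward formula with torch.library.register_autograd. For reference, a self-contained sketch of that pattern; the op definition is not shown in the hunk, so it is reconstructed here as an assumption.

import torch
from torch.library import _scoped_library

with _scoped_library("mylib", "FRAGMENT") as lib:
    # Assumed definition; the hunk only shows the impl and the autograd pieces.
    lib.define("foo(Tensor x, Tensor y) -> Tensor")

    def foo_impl(x, y):
        return torch.cat([x, y])

    lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

    def setup_context(ctx, inputs, output):
        (x, _) = inputs
        ctx.xs = x.shape[0]

    def foo_backward(ctx, grad):
        # cat along dim 0: split the incoming gradient back into the input shapes
        return grad[: ctx.xs], grad[ctx.xs :]

    torch.library.register_autograd(
        "mylib::foo", foo_backward, setup_context=setup_context
    )

    x = torch.randn(2, requires_grad=True)
    y = torch.randn(3, requires_grad=True)
    torch.ops.mylib.foo(x, y).sum().backward()
    print(x.grad.shape, y.grad.shape)  # torch.Size([2]) torch.Size([3])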

View File

@@ -1254,8 +1254,6 @@ def forward(self, x_1: "f32[2][1]cpu"):
torch._dynamo.reset()
mod = SimpleModule().cuda()
for p in mod.parameters():
p.requires_grad = False
compiled = torch.compile(mod, backend="inductor")
compiled(torch.randn(4, 4, device="cuda"))
@@ -1323,8 +1321,6 @@ def forward(self, x_1: "f32[2][1]cpu"):
torch._dynamo.reset()
mod = MixedModule().cuda()
for p in mod.parameters():
p.requires_grad = False
compiled = torch.compile(mod, backend="inductor")
compiled(torch.randn(4, 4, device="cuda"))
@@ -1379,8 +1375,6 @@ def forward(self, x_1: "f32[2][1]cpu"):
with self._setup_runtime_estimates_capture() as payload_buffer:
torch._dynamo.reset()
mod = Mixed().cuda()
for p in mod.parameters():
p.requires_grad = False
compiled = torch.compile(mod, backend="inductor")
compiled(torch.randn(4, 4, device="cuda"))
payload = payload_buffer.getvalue().strip()

View File

@@ -6,7 +6,6 @@ import warnings
import numpy as np
import torch
from torch._library.autograd import autograd_fallback_mode
from torch.library import _scoped_library
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@@ -16,6 +15,16 @@ from torch.testing._internal.common_utils import (
)
@contextlib.contextmanager
def autograd_fallback_mode(mode):
prev = torch._C._get_autograd_fallback_mode()
try:
torch._C._set_autograd_fallback_mode(mode)
yield
finally:
torch._C._set_autograd_fallback_mode(prev)
class TestAutogradFallback(TestCase):
test_ns = "_test_autograd_fallback"
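With the revert, the test file carries its own copy of this helper instead of importing it from torch._library.autograd. A minimal sketch of what TestAutogradFallback exercises with it; the op name is illustrative and the helper is assumed to be the one defined just above.

import warnings
import torch
from torch.library import _scoped_library

with _scoped_library("_test_autograd_fallback", "FRAGMENT") as lib:
    lib.define("passthrough(Tensor x) -> Tensor")
    lib.impl("passthrough", lambda x: x.clone(), "CPU")

    with autograd_fallback_mode("warn"):
        x = torch.randn(3, requires_grad=True)
        y = torch.ops._test_autograd_fallback.passthrough(x)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            y.sum().backward()
        # The fallback warns about the missing Autograd kernel instead of raising.
        assert any("autograd kernel was not registered" in str(w.message)
                   for w in caught)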

View File

@@ -26,7 +26,6 @@ from torch._dynamo.utils import (
from torch._guards import detect_fake_mode
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
from torch._inductor.utils import BoxedBool
from torch._library.autograd import autograd_fallback_mode
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.export._tree_utils import reorder_kwargs
from torch.fx.experimental.proxy_tensor import make_fx
@@ -529,9 +528,6 @@ def create_aot_state(
stack.enter_context(
torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing()
)
# Make it an error to backprop through PT2 compliant ops that silently
# detach autograd
stack.enter_context(autograd_fallback_mode("error"))
from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
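The removed lines, together with the helper deleted from torch/_library/autograd.py below, switched the C++ autograd fallback into error mode for the duration of AOT tracing. Roughly the following sketch; note that with this revert applied the C++ side rejects "error" again (the NYI check is restored below), so this only illustrates the removed behavior.

import contextlib
import torch

@contextlib.contextmanager
def autograd_fallback_mode(mode):
    # Same shape as the helper removed from torch/_library/autograd.py.
    prev = torch._C._get_autograd_fallback_mode()
    try:
        torch._C._set_autograd_fallback_mode(mode)
        yield
    finally:
        torch._C._set_autograd_fallback_mode(prev)

with contextlib.ExitStack() as stack:
    # Before the revert, create_aot_state entered this context so that tracing
    # backprop through an op with no Autograd kernel raised instead of warning.
    stack.enter_context(autograd_fallback_mode("error"))
    ...  # AOT tracing ran here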

View File

@@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
import contextlib
import dataclasses
from collections.abc import Callable
from dataclasses import dataclass
@@ -236,16 +235,6 @@ def not_list_of_optional_tensor(tree):
return True
@contextlib.contextmanager
def autograd_fallback_mode(mode):
prev = _C._get_autograd_fallback_mode()
try:
_C._set_autograd_fallback_mode(mode)
yield
finally:
_C._set_autograd_fallback_mode(prev)
flatten = _pytree.tree_flatten
unflatten = _pytree.tree_unflatten
spec_t = _pytree.TreeSpec

View File

@@ -50,6 +50,7 @@ AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn;
} // namespace
void setAutogradFallbackMode(AutogradFallbackMode mode) {
TORCH_CHECK(mode != AutogradFallbackMode::Error, "NYI: mode='error'");
kAutogradFallbackMode = mode;
}
@@ -57,10 +58,7 @@ AutogradFallbackMode getAutogradFallbackMode() {
return kAutogradFallbackMode;
}
static void reportAutogradNotImplemented(
const std::string& op_name,
bool is_warn) {
if (is_warn) {
static void warnAutogradNotImplemented(const std::string& op_name) {
TORCH_WARN(
op_name,
": an autograd kernel was not registered to the Autograd key(s) ",
@@ -71,47 +69,30 @@ static void reportAutogradNotImplemented(
"DispatchKey::CompositeImplicitAutograd). If your operator is not "
"differentiable, or to squash this warning and use the previous behavior, "
"please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
} else {
TORCH_CHECK(
0,
op_name,
": an autograd kernel was not registered to the Autograd key(s) ",
"but we are trying to backprop through it. This can lead to silently incorrect behavior. ",
"If your operator is differentiable, please ensure you have registered an "
"autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
"). If your operator is not "
"differentiable and ensure NO gradients flow through this operator, "
"please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.")
}
}
struct NotImplementedBackward : public Node {
NotImplementedBackward(
struct WarnNotImplemented : public Node {
WarnNotImplemented(
std::string op_name,
size_t num_outputs,
bool is_warn,
edge_list&& next_edges)
: Node(std::move(next_edges)),
op_name(std::move(op_name)),
num_outputs(num_outputs),
is_warn(is_warn) {}
num_outputs(num_outputs) {}
NotImplementedBackward(std::string op_name, size_t num_outputs, bool is_warn)
: op_name(std::move(op_name)),
num_outputs(num_outputs),
is_warn(is_warn) {}
WarnNotImplemented(std::string op_name, size_t num_outputs)
: op_name(std::move(op_name)), num_outputs(num_outputs) {}
variable_list apply(variable_list&& inputs) override;
std::string op_name;
size_t num_outputs;
bool is_warn;
};
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
auto NotImplementedBackward::apply(variable_list&& inputs) -> variable_list {
auto WarnNotImplemented::apply(variable_list&& inputs) -> variable_list {
auto inputsLocal = std::move(inputs);
reportAutogradNotImplemented(op_name, is_warn);
warnAutogradNotImplemented(op_name);
std::vector<at::Tensor> output(num_outputs);
return output;
}
@@ -130,6 +111,8 @@ static void basicAutogradNotImplementedFallbackImpl(
op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
return;
}
TORCH_INTERNAL_ASSERT(
getAutogradFallbackMode() == AutogradFallbackMode::Warn);
bool any_input_requires_grad = false;
_foreach_tensor(
@@ -145,9 +128,7 @@ static void basicAutogradNotImplementedFallbackImpl(
// by putting it after the requires_grad checks.
any_input_requires_grad = any_input_requires_grad && GradMode::is_enabled();
bool is_warn = getAutogradFallbackMode() == AutogradFallbackMode::Warn;
std::shared_ptr<NotImplementedBackward> grad_fn;
std::shared_ptr<WarnNotImplemented> grad_fn;
if (any_input_requires_grad) {
// NB: It is standard to collect edges from all tensors
// (see generated/VariableTypeEverything.cpp for examples)
@@ -159,9 +140,8 @@ static void basicAutogradNotImplementedFallbackImpl(
stack,
stack_start,
num_arguments);
grad_fn = std::shared_ptr<NotImplementedBackward>(
new NotImplementedBackward(
op_name, all_tensors_on_stack.size(), is_warn),
grad_fn = std::shared_ptr<WarnNotImplemented>(
new WarnNotImplemented(op_name, all_tensors_on_stack.size()),
deleteNode);
grad_fn->set_next_edges(collect_next_edges(all_tensors_on_stack));
}
@@ -197,8 +177,8 @@ static void basicAutogradNotImplementedFallbackImpl(
// >>> y = op(k)
// >>> torch.autograd.grad(z.sum(), w)
if (t.requires_grad()) {
t.register_hook([op_name, is_warn](const at::Tensor& grad) {
reportAutogradNotImplemented(op_name, is_warn);
t.register_hook([op_name](const at::Tensor& grad) {
warnAutogradNotImplemented(op_name);
});
// If history is rebased, then we will attempt to warn
// on the view's base. This will catch most cases (because
@@ -208,9 +188,8 @@ static void basicAutogradNotImplementedFallbackImpl(
const auto& base = t._base();
if (base.requires_grad()) {
// Can only register_hook on tensors that require grad.
base.register_hook(
[op_name, is_warn](const at::TensorBase& grad) {
reportAutogradNotImplemented(op_name, is_warn);
base.register_hook([op_name](const at::TensorBase& grad) {
warnAutogradNotImplemented(op_name);
});
}
}
@@ -218,9 +197,9 @@ static void basicAutogradNotImplementedFallbackImpl(
}
// If the post-autograd implementation returns any Tensors that
// don't require grad, then we install the NotImplementedBackward
// grad_fn. This grad_fn warns in backward and returns undefined
// tensor gradients.
// don't require grad, then we install the WarnNotImplemented grad_fn.
// This grad_fn warns in backward and returns undefined tensor
// gradients.
//
// NOTE [autograd fallback and in-place operations]
// If the schema says the output is mutable, and the output