From 5bcfdae71da2f405a89bace6b09d616107cc3965 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Fri, 31 Oct 2025 13:44:05 +0000
Subject: [PATCH] Revert "Make PT2 compile backprop through custom op without autograd key a hard error (#166367)"

This reverts commit 4acc66f1192ab7743abcc50383aefc5447447f9d.

Reverted https://github.com/pytorch/pytorch/pull/166367 on behalf of https://github.com/atalman due to internal build failures ([comment](https://github.com/pytorch/pytorch/pull/166367#issuecomment-3473150269))
---
 test/distributed/test_inductor_collectives.py | 11 +--
 test/dynamo/test_misc.py                      | 11 ---
 test/dynamo/test_structured_trace.py          |  6 --
 test/test_autograd_fallback.py                | 11 ++-
 torch/_functorch/aot_autograd.py              |  4 -
 torch/_library/autograd.py                    | 11 ---
 .../autograd_not_implemented_fallback.cpp     | 85 +++++++------
 7 files changed, 48 insertions(+), 91 deletions(-)

diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py
index 28f851b5fd6..62e5143d062 100644
--- a/test/distributed/test_inductor_collectives.py
+++ b/test/distributed/test_inductor_collectives.py
@@ -414,7 +414,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
 
         with _dynamo_dist_per_rank_init(self.rank, self.world_size):
             model = Model().to(self.device)
-            model.emb.weight.requires_grad = False
             model_compiled = torch.compile(model)
             inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device=self.device)
             out = model_compiled(inp, self.world_size, **self.get_world_trs())
@@ -1341,11 +1340,13 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         assert counter.op_count == 3  # It generates 2 getattr to unpack the array
         assert same(out, correct)
 
-    # This doesn't work in all cases, and now we properly loudly error.
-    # See: https://github.com/pytorch/pytorch/issues/151240
-    # When differentiable funcols are implemented can revert.
-    @unittest.expectedFailure
     def test_backwards(self):
+        """
+        It's probably not that common to need backwards support for collectives.
+
+        However, I wanted to at least see if it was possible to support it as a design goal.
+        """
+
         def func(inp):
             ar = _functional_collectives.all_reduce(inp, "sum", "0")
             return ar
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 9883d093681..c47a26a7f6f 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -9784,17 +9784,6 @@ def ___make_guard_fn():
         def foo_impl(x, y):
             return torch.cat([x, y])
 
-        def setup_context(ctx, inputs, output):
-            (x, _) = inputs
-            ctx.xs = x.shape[0]
-
-        def foo_backward(ctx, grad):
-            return grad[: ctx.xs], grad[ctx.xs :]
-
-        torch.library.register_autograd(
-            "mylib::foo", foo_backward, setup_context=setup_context
-        )
-
         @torch.compile(backend="aot_eager", fullgraph=True)
         def f(x, i):
             i0, i1 = i.tolist()
diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py
index 4c64f99a53c..e1e2b228062 100644
--- a/test/dynamo/test_structured_trace.py
+++ b/test/dynamo/test_structured_trace.py
@@ -1254,8 +1254,6 @@ def forward(self, x_1: "f32[2][1]cpu"):
             torch._dynamo.reset()
 
             mod = SimpleModule().cuda()
-            for p in mod.parameters():
-                p.requires_grad = False
             compiled = torch.compile(mod, backend="inductor")
             compiled(torch.randn(4, 4, device="cuda"))
 
@@ -1323,8 +1321,6 @@ def forward(self, x_1: "f32[2][1]cpu"):
             torch._dynamo.reset()
 
             mod = MixedModule().cuda()
-            for p in mod.parameters():
-                p.requires_grad = False
             compiled = torch.compile(mod, backend="inductor")
             compiled(torch.randn(4, 4, device="cuda"))
 
@@ -1379,8 +1375,6 @@ def forward(self, x_1: "f32[2][1]cpu"):
         with self._setup_runtime_estimates_capture() as payload_buffer:
             torch._dynamo.reset()
             mod = Mixed().cuda()
-            for p in mod.parameters():
-                p.requires_grad = False
             compiled = torch.compile(mod, backend="inductor")
             compiled(torch.randn(4, 4, device="cuda"))
             payload = payload_buffer.getvalue().strip()
diff --git a/test/test_autograd_fallback.py b/test/test_autograd_fallback.py
index 5748b5c4cca..d6252ac6f34 100644
--- a/test/test_autograd_fallback.py
+++ b/test/test_autograd_fallback.py
@@ -6,7 +6,6 @@ import warnings
 import numpy as np
 
 import torch
-from torch._library.autograd import autograd_fallback_mode
 from torch.library import _scoped_library
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -16,6 +15,16 @@ from torch.testing._internal.common_utils import (
 )
 
 
+@contextlib.contextmanager
+def autograd_fallback_mode(mode):
+    prev = torch._C._get_autograd_fallback_mode()
+    try:
+        torch._C._set_autograd_fallback_mode(mode)
+        yield
+    finally:
+        torch._C._set_autograd_fallback_mode(prev)
+
+
 class TestAutogradFallback(TestCase):
     test_ns = "_test_autograd_fallback"
 
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index f54179d2186..f48cb04f08f 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -26,7 +26,6 @@ from torch._dynamo.utils import (
 from torch._guards import detect_fake_mode
 from torch._inductor.cudagraph_utils import BoxedDeviceIndex
 from torch._inductor.utils import BoxedBool
-from torch._library.autograd import autograd_fallback_mode
 from torch._subclasses import FakeTensor, FakeTensorMode
 from torch.export._tree_utils import reorder_kwargs
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -529,9 +528,6 @@ def create_aot_state(
         stack.enter_context(
            torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing()
         )
-        # Make it an error to backprop through PT2 compliant ops that silently
-        # detach autograd
-        stack.enter_context(autograd_fallback_mode("error"))
 
    from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
 
diff --git a/torch/_library/autograd.py b/torch/_library/autograd.py
index 125ed5b73d8..2707d07059e 100644
--- a/torch/_library/autograd.py
+++ b/torch/_library/autograd.py
@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-import contextlib
 import dataclasses
 from collections.abc import Callable
 from dataclasses import dataclass
@@ -236,16 +235,6 @@ def not_list_of_optional_tensor(tree):
     return True
 
 
-@contextlib.contextmanager
-def autograd_fallback_mode(mode):
-    prev = _C._get_autograd_fallback_mode()
-    try:
-        _C._set_autograd_fallback_mode(mode)
-        yield
-    finally:
-        _C._set_autograd_fallback_mode(prev)
-
-
 flatten = _pytree.tree_flatten
 unflatten = _pytree.tree_unflatten
 spec_t = _pytree.TreeSpec
diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp
index 43204c3b234..9de461cc56a 100644
--- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp
+++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp
@@ -50,6 +50,7 @@ AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn;
 } // namespace
 
 void setAutogradFallbackMode(AutogradFallbackMode mode) {
+  TORCH_CHECK(mode != AutogradFallbackMode::Error, "NYI: mode='error'");
   kAutogradFallbackMode = mode;
 }
 
@@ -57,61 +58,41 @@ AutogradFallbackMode getAutogradFallbackMode() {
   return kAutogradFallbackMode;
 }
 
-static void reportAutogradNotImplemented(
-    const std::string& op_name,
-    bool is_warn) {
-  if (is_warn) {
-    TORCH_WARN(
-        op_name,
-        ": an autograd kernel was not registered to the Autograd key(s) ",
-        "but we are trying to backprop through it. This may lead to silently incorrect behavior. ",
-        "This behavior is deprecated and will be removed in a future version of PyTorch. ",
-        "If your operator is differentiable, please ensure you have registered an "
-        "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
-        "DispatchKey::CompositeImplicitAutograd). If your operator is not "
-        "differentiable, or to squash this warning and use the previous behavior, "
-        "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
-  } else {
-    TORCH_CHECK(
-        0,
-        op_name,
-        ": an autograd kernel was not registered to the Autograd key(s) ",
-        "but we are trying to backprop through it. This can lead to silently incorrect behavior. ",
-        "If your operator is differentiable, please ensure you have registered an "
-        "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
-        "). If your operator is not "
-        "differentiable and ensure NO gradients flow through this operator, "
-        "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.")
-  }
+static void warnAutogradNotImplemented(const std::string& op_name) {
+  TORCH_WARN(
+      op_name,
+      ": an autograd kernel was not registered to the Autograd key(s) ",
+      "but we are trying to backprop through it. This may lead to silently incorrect behavior. ",
+      "This behavior is deprecated and will be removed in a future version of PyTorch. ",
+      "If your operator is differentiable, please ensure you have registered an "
+      "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
+      "DispatchKey::CompositeImplicitAutograd). If your operator is not "
+      "differentiable, or to squash this warning and use the previous behavior, "
+      "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
 }
 
-struct NotImplementedBackward : public Node {
-  NotImplementedBackward(
+struct WarnNotImplemented : public Node {
+  WarnNotImplemented(
       std::string op_name,
       size_t num_outputs,
-      bool is_warn,
       edge_list&& next_edges)
       : Node(std::move(next_edges)),
         op_name(std::move(op_name)),
-        num_outputs(num_outputs),
-        is_warn(is_warn) {}
+        num_outputs(num_outputs) {}
 
-  NotImplementedBackward(std::string op_name, size_t num_outputs, bool is_warn)
-      : op_name(std::move(op_name)),
-        num_outputs(num_outputs),
-        is_warn(is_warn) {}
+  WarnNotImplemented(std::string op_name, size_t num_outputs)
+      : op_name(std::move(op_name)), num_outputs(num_outputs) {}
 
   variable_list apply(variable_list&& inputs) override;
 
   std::string op_name;
   size_t num_outputs;
-  bool is_warn;
 };
 
 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
-auto NotImplementedBackward::apply(variable_list&& inputs) -> variable_list {
+auto WarnNotImplemented::apply(variable_list&& inputs) -> variable_list {
   auto inputsLocal = std::move(inputs);
-  reportAutogradNotImplemented(op_name, is_warn);
+  warnAutogradNotImplemented(op_name);
   std::vector<at::Tensor> output(num_outputs);
   return output;
 }
@@ -130,6 +111,8 @@ static void basicAutogradNotImplementedFallbackImpl(
     op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
     return;
   }
+  TORCH_INTERNAL_ASSERT(
+      getAutogradFallbackMode() == AutogradFallbackMode::Warn);
 
   bool any_input_requires_grad = false;
   _foreach_tensor(
@@ -145,9 +128,7 @@ static void basicAutogradNotImplementedFallbackImpl(
   // by putting it after the requires_grad checks.
   any_input_requires_grad = any_input_requires_grad && GradMode::is_enabled();
 
-  bool is_warn = getAutogradFallbackMode() == AutogradFallbackMode::Warn;
-
-  std::shared_ptr<NotImplementedBackward> grad_fn;
+  std::shared_ptr<WarnNotImplemented> grad_fn;
   if (any_input_requires_grad) {
     // NB: It is standard to collect edges from all tensors
     // (see generated/VariableTypeEverything.cpp for examples)
@@ -159,9 +140,8 @@ static void basicAutogradNotImplementedFallbackImpl(
         stack,
         stack_start,
         num_arguments);
-    grad_fn = std::shared_ptr<NotImplementedBackward>(
-        new NotImplementedBackward(
-            op_name, all_tensors_on_stack.size(), is_warn),
+    grad_fn = std::shared_ptr<WarnNotImplemented>(
+        new WarnNotImplemented(op_name, all_tensors_on_stack.size()),
         deleteNode);
     grad_fn->set_next_edges(collect_next_edges(all_tensors_on_stack));
   }
@@ -197,8 +177,8 @@ static void basicAutogradNotImplementedFallbackImpl(
         // >>> y = op(k)
         // >>> torch.autograd.grad(z.sum(), w)
         if (t.requires_grad()) {
-          t.register_hook([op_name, is_warn](const at::Tensor& grad) {
-            reportAutogradNotImplemented(op_name, is_warn);
+          t.register_hook([op_name](const at::Tensor& grad) {
+            warnAutogradNotImplemented(op_name);
           });
           // If history is rebased, then we will attempt to warn
           // on the view's base. This will catch most cases (because
@@ -208,19 +188,18 @@ static void basicAutogradNotImplementedFallbackImpl(
             const auto& base = t._base();
             if (base.requires_grad()) {
               // Can only register_hook on tensors that require grad.
-              base.register_hook(
-                  [op_name, is_warn](const at::TensorBase& grad) {
-                    reportAutogradNotImplemented(op_name, is_warn);
-                  });
+              base.register_hook([op_name](const at::TensorBase& grad) {
+                warnAutogradNotImplemented(op_name);
+              });
             }
           }
           return;
         }
 
         // If the post-autograd implementation returns any Tensors that
-        // don't require grad, then we install the NotImplementedBackward
-        // grad_fn. This grad_fn warns in backward and returns undefined
-        // tensor gradients.
+        // don't require grad, then we install the WarnNotImplemented grad_fn.
+        // This grad_fn warns in backward and returns undefined tensor
+        // gradients.
         //
         // NOTE [autograd fallback and in-place operations]
         // If the schema says the output is mutable, and the output
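For context, the registration pattern that the reverted hard error was nudging custom-op authors toward is the one deleted from `test/dynamo/test_misc.py` above. Below is a minimal, hypothetical sketch of that pattern (the op name `mylib::cat2`, the fake impl, and the compiled wrapper `f` are invented for illustration; `torch.library.custom_op`, `register_fake`, and `register_autograd` are the public APIs used by the removed test). With an autograd kernel registered this way, the autograd-not-implemented fallback never has to warn or error during backprop:

```python
import torch


# Hypothetical custom op used only for illustration.
@torch.library.custom_op("mylib::cat2", mutates_args=())
def cat2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.cat([x, y])


# Fake (meta) implementation so PT2 can trace the op without running it
# (1-D inputs assumed for this sketch).
@torch.library.register_fake("mylib::cat2")
def _(x, y):
    return x.new_empty(x.shape[0] + y.shape[0])


def setup_context(ctx, inputs, output):
    # Stash the split point so backward can undo the concatenation.
    (x, _) = inputs
    ctx.xs = x.shape[0]


def cat2_backward(ctx, grad):
    # One gradient per input, obtained by splitting the output gradient.
    return grad[: ctx.xs], grad[ctx.xs :]


# Registering an autograd kernel means the not-implemented fallback in
# autograd_not_implemented_fallback.cpp is never reached for this op.
torch.library.register_autograd(
    "mylib::cat2", cat2_backward, setup_context=setup_context
)


@torch.compile(backend="aot_eager", fullgraph=True)
def f(x, y):
    return cat2(x, y)


x = torch.randn(3, requires_grad=True)
y = torch.randn(2, requires_grad=True)
f(x, y).sum().backward()  # gradients flow through the registered backward
```

Without the `register_autograd` call, the same compiled backward would reach the fallback path touched in `autograd_not_implemented_fallback.cpp`, which after this revert only warns (the pre-#166367 behavior) instead of raising a hard error.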