Improve error message when view of intermediate is returned from autograd.Function and marked dirty (#149543)

Fixes https://github.com/pytorch/pytorch/issues/149252
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149543
Approved by: https://github.com/zou3519
ghstack dependencies: #149220
This commit is contained in:
soulitzer 2025-03-19 13:15:50 -07:00 committed by PyTorch MergeBot
parent 7b218ca874
commit a2bba53f87
2 changed files with 27 additions and 3 deletions

View File

@ -8626,6 +8626,26 @@ for shape in [(1,), ()]:
self.assertTrue(out_dual is x_dual)
self.assertTrue(out_tangent is x_tangent)
def test_custom_function_mark_output_view_of_intermediate(self):
    """Passing a view of an intermediate tensor to ctx.mark_dirty() must
    raise: only input tensors that were mutated may be marked dirty.

    Regression test for the improved error message (gh-149252).
    """

    class ViewOfIntermediate(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            # `result` is a view of the clone (an intermediate), NOT of
            # `inp` itself — so marking it dirty is invalid.
            result = inp.clone().view_as(inp)
            ctx.mark_dirty(result)
            return result

        @staticmethod
        def backward(ctx, gO):
            pass

    leaf = torch.tensor([1.0], requires_grad=True)
    non_leaf = leaf.clone()
    expected_msg = "received a tensor that was not an input."
    with self.assertRaisesRegex(RuntimeError, expected_msg):
        ViewOfIntermediate.apply(non_leaf)
def test_named_tensor_for_complex_views(self):
names = ["batch", "height", "width", "complex"]
z = torch.ones((2, 1, 2, 2), requires_grad=True)

View File

@ -309,9 +309,13 @@ static optional_variable_list _process_backward_mode_ad(
}
// No need to mark as modified Tensors that are not inputs.
if (!is_input) {
TORCH_WARN(
"Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there"
" is no need to pass it to mark_dirty().");
const char* mark_dirty_error_msg =
"ctx.mark_dirty() received a tensor that was not an input. "
"Only input Tensors that have been mutated should be passed to "
"ctx.mark_dirty().";
// We reach this path in the view of intermediate case
TORCH_CHECK(!var.is_view(), mark_dirty_error_msg);
TORCH_WARN(mark_dirty_error_msg);
}
// If the input is a view, the rebase will need to rewrite the graph and
// this only works if we have a single output to this Function.