mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Improve error message when view of intermediate is returned from autograd.Function and marked dirty (#149543)
Fixes https://github.com/pytorch/pytorch/issues/149252 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149543 Approved by: https://github.com/zou3519 ghstack dependencies: #149220
This commit is contained in:
parent
7b218ca874
commit
a2bba53f87
|
|
@ -8626,6 +8626,26 @@ for shape in [(1,), ()]:
|
|||
self.assertTrue(out_dual is x_dual)
|
||||
self.assertTrue(out_tangent is x_tangent)
|
||||
|
||||
def test_custom_function_mark_output_view_of_intermediate(self):
|
||||
class Func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, inp):
|
||||
out = inp.clone().view_as(inp)
|
||||
ctx.mark_dirty(out)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, gO):
|
||||
pass
|
||||
|
||||
a = torch.tensor([1.0], requires_grad=True)
|
||||
a_clone = a.clone()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "received a tensor that was not an input."
|
||||
):
|
||||
Func.apply(a_clone)
|
||||
|
||||
def test_named_tensor_for_complex_views(self):
|
||||
names = ["batch", "height", "width", "complex"]
|
||||
z = torch.ones((2, 1, 2, 2), requires_grad=True)
|
||||
|
|
|
|||
|
|
@ -309,9 +309,13 @@ static optional_variable_list _process_backward_mode_ad(
|
|||
}
|
||||
// No need to mark as modified Tensors that are not inputs.
|
||||
if (!is_input) {
|
||||
TORCH_WARN(
|
||||
"Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there"
|
||||
" is no need to pass it to mark_dirty().");
|
||||
const char* mark_dirty_error_msg =
|
||||
"ctx.mark_dirty() received a tensor that was not an input. "
|
||||
"Only input Tensors that have been mutated should be passed to "
|
||||
"ctx.mark_dirty().";
|
||||
// We reach this path in the view of intermediate case
|
||||
TORCH_CHECK(!var.is_view(), mark_dirty_error_msg);
|
||||
TORCH_WARN(mark_dirty_error_msg);
|
||||
}
|
||||
// If the input is a view, the rebase will need to rewrite the graph and
|
||||
// this only works if we have a single output to this Function.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user