mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Improve error message when view of intermediate is returned from autograd.Function and marked dirty (#149543)
Fixes https://github.com/pytorch/pytorch/issues/149252 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149543 Approved by: https://github.com/zou3519 ghstack dependencies: #149220
This commit is contained in:
parent
7b218ca874
commit
a2bba53f87
|
|
@ -8626,6 +8626,26 @@ for shape in [(1,), ()]:
|
||||||
self.assertTrue(out_dual is x_dual)
|
self.assertTrue(out_dual is x_dual)
|
||||||
self.assertTrue(out_tangent is x_tangent)
|
self.assertTrue(out_tangent is x_tangent)
|
||||||
|
|
||||||
|
def test_custom_function_mark_output_view_of_intermediate(self):
|
||||||
|
class Func(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, inp):
|
||||||
|
out = inp.clone().view_as(inp)
|
||||||
|
ctx.mark_dirty(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, gO):
|
||||||
|
pass
|
||||||
|
|
||||||
|
a = torch.tensor([1.0], requires_grad=True)
|
||||||
|
a_clone = a.clone()
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, "received a tensor that was not an input."
|
||||||
|
):
|
||||||
|
Func.apply(a_clone)
|
||||||
|
|
||||||
def test_named_tensor_for_complex_views(self):
|
def test_named_tensor_for_complex_views(self):
|
||||||
names = ["batch", "height", "width", "complex"]
|
names = ["batch", "height", "width", "complex"]
|
||||||
z = torch.ones((2, 1, 2, 2), requires_grad=True)
|
z = torch.ones((2, 1, 2, 2), requires_grad=True)
|
||||||
|
|
|
||||||
|
|
@ -309,9 +309,13 @@ static optional_variable_list _process_backward_mode_ad(
|
||||||
}
|
}
|
||||||
// No need to mark as modified Tensors that are not inputs.
|
// No need to mark as modified Tensors that are not inputs.
|
||||||
if (!is_input) {
|
if (!is_input) {
|
||||||
TORCH_WARN(
|
const char* mark_dirty_error_msg =
|
||||||
"Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there"
|
"ctx.mark_dirty() received a tensor that was not an input. "
|
||||||
" is no need to pass it to mark_dirty().");
|
"Only input Tensors that have been mutated should be passed to "
|
||||||
|
"ctx.mark_dirty().";
|
||||||
|
// We reach this path in the view of intermediate case
|
||||||
|
TORCH_CHECK(!var.is_view(), mark_dirty_error_msg);
|
||||||
|
TORCH_WARN(mark_dirty_error_msg);
|
||||||
}
|
}
|
||||||
// If the input is a view, the rebase will need to rewrite the graph and
|
// If the input is a view, the rebase will need to rewrite the graph and
|
||||||
// this only works if we have a single output to this Function.
|
// this only works if we have a single output to this Function.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user