Fix torch.autograd.backward inputs validation (#150975)

- Fixes #150883
- Fixes #70504

This is my first PR to pytorch, so please tell me if I'm forgetting anything.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150975
Approved by: https://github.com/soulitzer
This commit is contained in:
Valérian Rey 2025-04-17 02:11:10 +00:00 committed by PyTorch MergeBot
parent 6f9ffaa991
commit f5851efed9
2 changed files with 17 additions and 10 deletions

View File

@ -2507,6 +2507,12 @@ class TestAutograd(TestCase):
lambda: torch.autograd.backward(fn(), gradient, inputs=[]),
)
def test_backward_with_scalar_input(self):
x = torch.randn([], dtype=torch.double, requires_grad=True)
out = x**2
out.backward(inputs=x)
self.assertEqual(x.grad, 2 * x)
def test_backward_with_nonleaf_inputs(self):
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
x_nonleaf = x * 1

View File

@ -325,8 +325,16 @@ def backward(
"arguments both passed to `backward()`. Please only "
"use `grad_tensors`."
)
if inputs is not None and len(inputs) == 0:
raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")
inputs_tuple: tuple[Union[torch.Tensor, graph.GradientEdge], ...]
if inputs is None:
inputs_tuple = ()
elif isinstance(inputs, (torch.Tensor, graph.GradientEdge)):
inputs_tuple = (inputs,)
else:
inputs_tuple = tuple(inputs)
if len(inputs_tuple) == 0:
raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")
if is_tensor_like(tensors) or isinstance(tensors, graph.GradientEdge):
tensors = cast(
@ -334,13 +342,6 @@ def backward(
)
else:
tensors = tuple(tensors)
inputs = (
(inputs,)
if isinstance(inputs, (torch.Tensor, graph.GradientEdge))
else tuple(inputs)
if inputs is not None
else ()
)
grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
@ -355,7 +356,7 @@ def backward(
grad_tensors_,
retain_graph,
create_graph,
inputs,
inputs_tuple,
allow_unreachable=True,
accumulate_grad=True,
)