Fix torch.autograd.backward inputs validation (#150975)
- Fixes #150883
- Fixes #70504

This is my first PR to pytorch, so please tell me if I'm forgetting anything.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150975
Approved by: https://github.com/soulitzer
This commit is contained in:
parent
6f9ffaa991
commit
f5851efed9
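
For context, a minimal sketch of the behavior this commit targets (tensor names are illustrative; the error string matches the one used in the diff below):

    import torch

    # Before this fix, a 0-dim (scalar) tensor passed as `inputs` hit
    # `len(inputs)` in the empty-inputs check and failed (#150883).
    # Now a lone Tensor is wrapped into a tuple before any len() call.
    x = torch.randn([], dtype=torch.double, requires_grad=True)
    out = x**2
    out.backward(inputs=x)
    assert torch.equal(x.grad, 2 * x)

    # An explicitly empty `inputs` collection is still rejected:
    y = torch.randn(2, requires_grad=True)
    try:
        (y * 2).sum().backward(inputs=[])
    except RuntimeError as err:
        print(err)  # `inputs` argument to `backward()` cannot be empty.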
@@ -2507,6 +2507,12 @@ class TestAutograd(TestCase):
             lambda: torch.autograd.backward(fn(), gradient, inputs=[]),
         )
 
+    def test_backward_with_scalar_input(self):
+        x = torch.randn([], dtype=torch.double, requires_grad=True)
+        out = x**2
+        out.backward(inputs=x)
+        self.assertEqual(x.grad, 2 * x)
+
     def test_backward_with_nonleaf_inputs(self):
         x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
         x_nonleaf = x * 1

@@ -325,8 +325,16 @@ def backward(
             "arguments both passed to `backward()`. Please only "
             "use `grad_tensors`."
         )
-    if inputs is not None and len(inputs) == 0:
-        raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")
+
+    inputs_tuple: tuple[Union[torch.Tensor, graph.GradientEdge], ...]
+    if inputs is None:
+        inputs_tuple = ()
+    elif isinstance(inputs, (torch.Tensor, graph.GradientEdge)):
+        inputs_tuple = (inputs,)
+    else:
+        inputs_tuple = tuple(inputs)
+        if len(inputs_tuple) == 0:
+            raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")
 
     if is_tensor_like(tensors) or isinstance(tensors, graph.GradientEdge):
         tensors = cast(
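
Side note on the hunk above (a standalone sketch, not part of the diff): the isinstance branch has to run before tuple()/len(), because len() raises on a 0-dim tensor:

    import torch

    s = torch.tensor(3.0, requires_grad=True)  # 0-dim (scalar) tensor
    try:
        len(s)
    except TypeError as err:
        print(err)  # len() of a 0-d tensor
    # Hence a lone Tensor (or GradientEdge) is wrapped as (s,) directly,
    # and the empty check only runs for actual collections.
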
@@ -334,13 +342,6 @@ def backward(
         )
     else:
         tensors = tuple(tensors)
-    inputs = (
-        (inputs,)
-        if isinstance(inputs, (torch.Tensor, graph.GradientEdge))
-        else tuple(inputs)
-        if inputs is not None
-        else ()
-    )
 
     grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
     grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)

@@ -355,7 +356,7 @@ def backward(
         grad_tensors_,
         retain_graph,
         create_graph,
-        inputs,
+        inputs_tuple,
         allow_unreachable=True,
         accumulate_grad=True,
     )
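
Usage note (an illustrative sketch; tensors `a` and `b` are made up): the normalized inputs_tuple handed to the engine controls which leaves receive gradients, so when `inputs=` is given, .grad is accumulated only into the listed tensors:

    import torch

    a = torch.randn(3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)
    loss = (a * b).sum()
    torch.autograd.backward(loss, inputs=[a])
    print(a.grad is not None)  # True
    print(b.grad)              # None: b was not passed via `inputs`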