mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Taken from https://github.com/pytorch/pytorch/pull/60516 Pull Request resolved: https://github.com/pytorch/pytorch/pull/75262 Approved by: https://github.com/Krovatkin
52 lines
1.6 KiB
Python
52 lines
1.6 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import torch
|
|
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
from typing import List
|
|
|
|
class TestAutodiffJit(JitTestCase):
    """Autodiff tests for functions compiled with ``torch.jit.script``."""

    def test_undefined_tensor_lists(self):
        """Check the scripted backward of a graph that consumes a tensor list.

        A scripted function concatenates a list of tensors and adds a second
        tensor. After some forward/backward runs, the autograd node that
        produced the scripted output is called directly. It is called first
        with a defined ``grad_output`` and then with an undefined (``None``)
        one. In both cases it must return three gradients (for ``y`` and for
        each of the two list elements). In the undefined case, every
        gradient must be ``None`` or contain only zeros.
        """

        def fn(tensor_list: List[torch.Tensor], add_tensor):
            cat = torch.cat(tensor_list, dim=1)
            r = torch.sin(cat + add_tensor)
            return r

        fn_s = torch.jit.script(fn)

        a = torch.rand((3, 6), requires_grad=True)
        b = torch.rand((3, 10), requires_grad=True)
        x = [a, b]
        # Catting (3, 6) and (3, 10) along dim 1 gives (3, 16), matching y.
        y = torch.rand((3, 16), requires_grad=True)

        # NOTE(review): these look like warm-up runs, so the JIT executor has
        # profiled/optimized the graph before its backward is inspected
        # below. Confirm how many runs the active executor actually needs.
        ret = fn_s(x, y)
        ret.sum().backward()
        ret = fn_s(x, y)
        ret.sum().backward()

        ret = fn_s(x, y)
        s = ret.sum()

        # backward_fn expects 2 inputs: (grad_output, current_grad_r)
        # current_grad_r is provided because we need to add this contribution
        # to grad_r when we return it.
        # next_functions[0][0] is the autograd node that produced ``ret`` (the
        # input to ``sum``), presumably the scripted graph's backward node.
        backward_fn = s.grad_fn.next_functions[0][0]

        # check behavior with defined tensor
        grad_out = torch.rand((3, 16))
        grad_inputs = backward_fn(grad_out, None)

        # expect 3 tensors: grad_y, grad_a, grad_b
        self.assertEqual(3, len(grad_inputs))
        # ``grad`` (not ``x``) so the input list above is not shadowed.
        for grad in grad_inputs:
            self.assertIsInstance(grad, torch.Tensor)

        # now test with undefined grad_out
        grad_inputs = backward_fn(None, None)

        # Still 3 outputs. Each must be undefined (None) or, if the backward
        # materialized it anyway, contain only zeros.
        self.assertEqual(3, len(grad_inputs))
        for grad in grad_inputs:
            if grad is not None:
                self.assertEqual(0, torch.max(torch.abs(grad)).item())
|