mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[functorch] Added some make_fx+vjp/jac/vmap tests
This commit is contained in:
parent
8e62e271be
commit
6d39fa335b
|
|
@ -48,6 +48,35 @@ class TestPythonKey(TestCase):
|
|||
new_inp = torch.randn(3)
|
||||
self.assertEqual(fx_f(new_inp), f(new_inp))
|
||||
|
||||
def test_make_fx_vmap(self, device):
|
||||
def f(x):
|
||||
return torch.sin(x)
|
||||
inp = torch.randn(5, 3)
|
||||
f = vmap(f)
|
||||
fx_f = make_fx(f)(inp)
|
||||
new_inp = torch.randn(5, 3)
|
||||
self.assertEqual(fx_f(new_inp), f(new_inp))
|
||||
|
||||
def test_make_fx_jacrev(self, device):
|
||||
def f(x):
|
||||
return x.sin().sum()
|
||||
inp = torch.randn(3)
|
||||
f = jacrev(jacrev(f))
|
||||
fx_f = make_fx(f)(inp)
|
||||
new_inp = torch.randn(3)
|
||||
self.assertEqual(fx_f(new_inp), f(new_inp))
|
||||
|
||||
def test_make_fx_jvp(self, device):
|
||||
def f(x):
|
||||
return torch.sin(x).sum()
|
||||
|
||||
primals = torch.randn(3)
|
||||
_, vjp_fn = vjp(f, primals)
|
||||
cotangent = torch.randn(())
|
||||
fx_f = make_fx(vjp_fn)(cotangent, True, True)
|
||||
new_cotangent = torch.randn(())
|
||||
self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
|
||||
|
||||
def test_nnc_jit(self, device):
|
||||
def f(x):
|
||||
return torch.sin(x)
|
||||
|
|
@ -98,6 +127,7 @@ class TestPythonKey(TestCase):
|
|||
|
||||
|
||||
|
||||
|
||||
only_for = ("cpu")
|
||||
instantiate_device_type_tests(
|
||||
TestPythonKey,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user