[functorch] Added some make_fx+vjp/jac/vmap tests

This commit is contained in:
Horace He 2021-07-17 18:35:37 -07:00 committed by Jon Janzen
parent 8e62e271be
commit 6d39fa335b

View File

@ -48,6 +48,35 @@ class TestPythonKey(TestCase):
new_inp = torch.randn(3)
self.assertEqual(fx_f(new_inp), f(new_inp))
def test_make_fx_vmap(self, device):
def f(x):
return torch.sin(x)
inp = torch.randn(5, 3)
f = vmap(f)
fx_f = make_fx(f)(inp)
new_inp = torch.randn(5, 3)
self.assertEqual(fx_f(new_inp), f(new_inp))
def test_make_fx_jacrev(self, device):
def f(x):
return x.sin().sum()
inp = torch.randn(3)
f = jacrev(jacrev(f))
fx_f = make_fx(f)(inp)
new_inp = torch.randn(3)
self.assertEqual(fx_f(new_inp), f(new_inp))
def test_make_fx_jvp(self, device):
def f(x):
return torch.sin(x).sum()
primals = torch.randn(3)
_, vjp_fn = vjp(f, primals)
cotangent = torch.randn(())
fx_f = make_fx(vjp_fn)(cotangent, True, True)
new_cotangent = torch.randn(())
self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
def test_nnc_jit(self, device):
def f(x):
return torch.sin(x)
@ -98,6 +127,7 @@ class TestPythonKey(TestCase):
only_for = ("cpu")
instantiate_device_type_tests(
TestPythonKey,