pytorch/functorch/test/test_pythonkey.py
Richard Zou b29e666ade [functorch] [BC-breaking] Update make_functional* (pytorch/functorch#52)
Updates make_functional to use the new improved variants. The new
variants are superior in every way so we're replacing the previous
variants with this.

If someone wants the older variants, they can be found at:
- make_functional_with_buffers_deprecated_v1
- make_functional_deprecated_v1
2022-07-21 13:40:55 -07:00

101 lines
2.5 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch.nn as nn
import torch.nn.functional as F
import unittest
import functools
import itertools
import warnings
import math
from typing import Callable, Type
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
skipCUDAIfNoMagma, onlyOnCPUAndCUDA, onlyCPU
import types
from functools import partial
import functorch
from functorch import (
grad, vjp, vmap, jacrev, grad_and_value,
make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1, make_fx, nnc_jit
)
# NB: numpy is a testing dependency!
import numpy as np
class TestPythonKey(TestCase):
    """Tests for functorch's ``make_fx`` and ``nnc_jit``.

    Each test takes the ``device`` argument supplied by
    ``instantiate_device_type_tests`` and creates its inputs on that device,
    so a test instantiated for a given device actually runs there instead of
    silently falling back to the default device.
    """

    def test_make_fx(self, device):
        # Trace f with one input, then check the traced function against
        # eager f on a *different* input of the same shape.
        def f(x):
            return torch.sin(x)

        inp = torch.randn(3, device=device)
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(3, device=device)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_nnc_jit(self, device):
        # nnc_jit(f) must agree with eager f on a 1-D tensor.
        def f(x):
            return torch.sin(x)

        jit_f = nnc_jit(f)
        inp = torch.randn(3, device=device)
        self.assertEqual(jit_f(inp), f(inp))

    def test_nnc_scalar(self, device):
        # Same as test_nnc_jit, but with a 0-dim (scalar) tensor input.
        def f(x):
            return torch.sin(x)

        jit_f = nnc_jit(f)
        inp = torch.randn((), device=device)
        self.assertEqual(jit_f(inp), f(inp))

    def test_nnc_pytrees(self, device):
        # Input and output are lists wrapping tensors rather than bare tensors.
        def f(x):
            return [torch.sin(x[0])]

        jit_f = nnc_jit(f)
        inp = [torch.randn(3, device=device)]
        self.assertEqual(jit_f(inp), f(inp))

    def test_external_calls(self, device):
        # f calls torch.mv (matrix-vector product) on two positional inputs.
        def f(a, b):
            return torch.mv(a, b)

        jit_f = nnc_jit(f)
        inp = [torch.randn(3, 3, device=device), torch.randn(3, device=device)]
        self.assertEqual(jit_f(*inp), f(*inp))

    def test_nnc_passthrough(self, device):
        # Tuple output where the second element is an input returned as-is.
        def f_tuple(x, y):
            return x + y, y

        inp = (torch.randn(3, device=device), torch.randn(3, device=device))
        jit_f = nnc_jit(f_tuple)
        self.assertEqual(jit_f(*inp), f_tuple(*inp))

        # Dict input/output: 'a' is recomputed and 'b' passes through.
        # f_dict rebinds x['a'] on the dict it receives, so each call gets
        # its own dict. Otherwise the eager reference call would mutate the
        # very dict handed to the jitted call, and the comparison would hinge
        # on argument evaluation order.
        def f_dict(x):
            x['a'] = x['a'] * 2
            return x

        a = torch.randn(3, device=device)
        b = torch.randn(3, device=device)
        jit_f = nnc_jit(f_dict)
        actual = jit_f({'a': a, 'b': b})
        expected = f_dict({'a': a, 'b': b})
        self.assertEqual(actual, expected)
# Device types TestPythonKey is instantiated for. The trailing comma matters:
# ("cpu") is just the string "cpu", whereas ("cpu",) is a one-element tuple
# of device types, which is what this argument is meant to hold.
only_for = ("cpu",)

# Generate device-specific variants of TestPythonKey into this module's
# globals(), restricted to the device types listed in only_for.
instantiate_device_type_tests(
    TestPythonKey,
    globals(),
    only_for=only_for,
)

if __name__ == '__main__':
    # Hand control to PyTorch's internal test runner.
    run_tests()