mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
92 lines
2.3 KiB
Python
92 lines
2.3 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import unittest
|
|
import functools
|
|
import itertools
|
|
import warnings
|
|
import math
|
|
from typing import Callable, Type
|
|
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
|
|
skipCUDAIfNoMagma, onlyOnCPUAndCUDA, onlyCPU
|
|
import types
|
|
from functools import partial
|
|
|
|
import functorch
|
|
from functorch import (
|
|
grad, vjp, vmap, jacrev, grad_and_value,
|
|
make_functional, make_functional_with_buffers, make_fx, nnc_jit
|
|
)
|
|
|
|
# NB: numpy is a testing dependency!
|
|
import numpy as np
|
|
|
|
class TestPythonKey(TestCase):
    """Checks functorch's ``make_fx`` tracer and ``nnc_jit`` compiler against eager mode.

    Each test builds a small Python function, traces or compiles it, and
    asserts that the transformed function produces the same result as calling
    the original function directly.

    The methods are instantiated per device by ``instantiate_device_type_tests``,
    which passes the target device in as ``device``. Inputs are allocated on
    that device so that enabling more device types really exercises them.
    """

    def test_make_fx(self, device):
        def f(x):
            return torch.sin(x)

        inp = torch.randn(3, device=device)
        fx_f = make_fx(f)(inp)

        # Evaluate the traced graph on a different tensor than the one it was
        # traced with, to check it is not specialized to the tracing input.
        new_inp = torch.randn(3, device=device)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_nnc_jit(self, device):
        def f(x):
            return torch.sin(x)

        jit_f = nnc_jit(f)

        inp = torch.randn(3, device=device)
        self.assertEqual(jit_f(inp), f(inp))

    def test_nnc_pytrees(self, device):
        # Both the argument and the return value are lists wrapping tensors
        # (nested "pytree" structures), not bare tensors.
        def f(x):
            return [torch.sin(x[0])]

        jit_f = nnc_jit(f)

        inp = [torch.randn(3, device=device)]
        self.assertEqual(jit_f(inp), f(inp))

    def test_external_calls(self, device):
        def f(a, b):
            return torch.mv(a, b)

        jit_f = nnc_jit(f)
        inp = [torch.randn(3, 3, device=device), torch.randn(3, device=device)]
        self.assertEqual(jit_f(*inp), f(*inp))

    def test_nnc_passthrough(self, device):
        # Case 1: one output is computed, the other is an input (``y``)
        # returned unchanged.
        def f(x, y):
            return x + y, y

        inp = (torch.randn(3, device=device), torch.randn(3, device=device))
        jit_f = nnc_jit(f)
        self.assertEqual(jit_f(*inp), f(*inp))

        # Case 2: a dict argument in which entry 'a' is replaced and entry 'b'
        # passes through untouched.
        def g(x):
            x['a'] = x['a'] * 2
            return x

        dict_inp = ({
            'a': torch.randn(3, device=device),
            'b': torch.randn(3, device=device),
        },)
        jit_g = nnc_jit(g)
        actual = jit_g(*dict_inp)
        # ``g`` assigns into the dict it receives. Give the eager reference a
        # shallow copy so ``dict_inp`` itself is never mutated, which keeps the
        # comparison independent of which side is evaluated first.
        expected = g(dict(dict_inp[0]))
        self.assertEqual(actual, expected)
|
|
|
|
|
|
|
|
# Generate device-specific variants of TestPythonKey (CPU only for now).
# NB: the trailing comma is required. ``("cpu")`` is just the string "cpu",
# and a membership test against a string is a substring check, not an
# element check.
only_for = ("cpu",)
instantiate_device_type_tests(
    TestPythonKey,
    globals(),
    only_for=only_for,
)
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed directly as a script (not imported): hand control to the
    # run_tests helper imported from torch.testing._internal.common_utils.
    run_tests()
|