# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch.nn as nn
import torch.nn.functional as F
import unittest
import functools
import itertools
import warnings
import math
from typing import Callable, Type
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
    skipCUDAIfNoMagma, onlyOnCPUAndCUDA, onlyCPU
import types
from functools import partial

import functorch
from functorch import (
    grad, vjp, vmap, jacrev, grad_and_value,
    make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1, make_fx, nnc_jit
)

# NB: numpy is a testing dependency!
import numpy as np


class TestPythonKey(TestCase):
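    # make_fx traces the function on an example input into an FX graph; the traced
    # graph should compute the same values as eager mode on a fresh input.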
    def test_make_fx(self, device):
        def f(x):
            return torch.sin(x)
        inp = torch.randn(3)
        fx_f = make_fx(f)(inp)

        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

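    # nnc_jit compiles the function; the compiled version should match eager execution.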
def test_nnc_jit(self, device):
|
|
def f(x):
|
|
return torch.sin(x)
|
|
|
|
jit_f = nnc_jit(f)
|
|
|
|
inp = torch.randn(3)
|
|
self.assertEqual(jit_f(inp), f(inp))
|
|
|
|
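    # Same check with a 0-dim (scalar) tensor input.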
def test_nnc_scalar(self, device):
|
|
def f(x):
|
|
return torch.sin(x)
|
|
|
|
jit_f = nnc_jit(f)
|
|
|
|
inp = torch.randn(())
|
|
self.assertEqual(jit_f(inp), f(inp))
|
|
|
|
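    # Inputs and outputs wrapped in containers (pytrees) should round-trip through nnc_jit.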
def test_nnc_pytrees(self, device):
|
|
def f(x):
|
|
return [torch.sin(x[0])]
|
|
|
|
jit_f = nnc_jit(f)
|
|
|
|
inp = [torch.randn(3)]
|
|
self.assertEqual(jit_f(inp), f(inp))
|
|
|
|
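    # torch.mv is presumably handled through NNC's external-call path rather than
    # generated code; either way the result should match eager execution.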
def test_external_calls(self, device):
|
|
def f(a, b):
|
|
return torch.mv(a, b)
|
|
jit_f = nnc_jit(f)
|
|
inp = [torch.randn(3, 3), torch.randn(3)]
|
|
self.assertEqual(jit_f(*inp), f(*inp))
|
|
|
|
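    # Values returned unchanged (including untouched dict entries) should pass
    # straight through the compiled function.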
def test_nnc_passthrough(self, device):
|
|
def f(x, y):
|
|
return x + y, y
|
|
inp = (torch.randn(3), torch.randn(3))
|
|
jit_f = nnc_jit(f)
|
|
self.assertEqual(jit_f(*inp), f(*inp))
|
|
|
|
def f(x):
|
|
x['a'] = x['a'] * 2
|
|
return x
|
|
inp = ({'a': torch.randn(3), 'b': torch.randn(3)},)
|
|
jit_f = nnc_jit(f)
|
|
self.assertEqual(jit_f(*inp), f(*inp))
|
|
|
|
|
|
|
|
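# Instantiate per-device variants of the tests above, restricted to CPU.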
only_for = ("cpu",)
instantiate_device_type_tests(
    TestPythonKey,
    globals(),
    only_for=only_for,
)


if __name__ == '__main__':
    run_tests()