pytorch/test/jit/test_python_bindings.py
Thomas Viehmann ac0a3cc5fd Merge CompilationUnit from torch._C and torch.jit (#50614)
Summary:
This simplifies our handling and makes it easy to pass CompilationUnits from Python to C++-defined functions via pybind11.

Discussed on Slack with SplitInfinity

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50614

Reviewed By: anjali411

Differential Revision: D25938005

Pulled By: SplitInfinity

fbshipit-source-id: 94aadf0c063ddfef7ca9ea17bfa998d8e7b367ad
2021-01-25 11:06:40 -08:00

38 lines
1.1 KiB
Python

import torch
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
    # This module is only meant to be collected through test/test_jit.py;
    # running it standalone is refused with a pointer to the right command.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TestPythonBindings\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
class TestPythonBindings(JitTestCase):
def test_cu_get_functions(self):
@torch.jit.script
def test_get_python_cu_fn(x: torch.Tensor):
return 2 * x
cu = torch.jit._state._python_cu
self.assertTrue(
"test_get_python_cu_fn" in (str(fn.name) for fn in cu.get_functions())
)
def test_cu_create_function(self):
@torch.jit.script
def fn(x: torch.Tensor):
return 2 * x
cu = torch._C.CompilationUnit()
cu.create_function("test_fn", fn.graph)
inp = torch.randn(5)
self.assertEqual(inp * 2, cu.find_function("test_fn")(inp))
self.assertEqual(cu.find_function("doesnt_exist"), None)
self.assertEqual(inp * 2, cu.test_fn(inp))
with self.assertRaises(AttributeError):
cu.doesnt_exist(inp)