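# Tests for custom operator lookup and invocation through `torch.ops`, both
# eagerly and from TorchScript (scripted and traced functions).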
import os
import sys
import unittest

import torch

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from jit_utils import JitTestCase

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")

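# Helper: canonicalize a graph and return its textual IR for comparison.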
def canonical(graph):
    return torch._C._jit_pass_canonicalize(graph).str(False)

class TestCustomOperators(JitTestCase):

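    # Accessing an attribute on `torch.ops` lazily creates and caches the
    # corresponding `_OpNamespace`; attributes on that namespace resolve to
    # callable operators.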
    def test_dynamic_op_registry(self):
        from torch._ops import _OpNamespace
        self.assertTrue(hasattr(torch, 'ops'))

        if '_test' in torch.ops.__dict__:
            torch.ops.__dict__.pop('_test')

        # Don't use `hasattr()` because it will call `__getattr__`.
        self.assertNotIn('_test', torch.ops.__dict__)
        torch.ops._test
        self.assertIn('_test', torch.ops.__dict__)
        self.assertEqual(type(torch.ops._test), _OpNamespace)

        self.assertNotIn('leaky_relu', torch.ops._test.__dict__)
        op = torch.ops._test.leaky_relu
        self.assertTrue(callable(op))
        self.assertIn('leaky_relu', torch.ops._test.__dict__)
        op2 = torch.ops._test.leaky_relu
        self.assertEqual(op, op2)

    def test_simply_calling_an_operator(self):
        input = torch.randn(100)
        output = torch.ops.aten.relu(input)
        self.assertEqual(output, input.relu())

    def test_default_arguments_are_used(self):
        output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0]))
        self.assertEqual(output, torch.tensor([-0.01, 1]))

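    # `self` is the schema name of the tensor argument, so it can also be
    # passed as a keyword argument.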
    def test_only_kwargs(self):
        output = torch.ops._test.leaky_relu(self=torch.tensor(-1.0))
        self.assertEqual(output, torch.tensor(-0.01))

    def test_passing_too_many_args(self):
        with self.assertRaisesRegex(
            RuntimeError,
            r"aten::relu\(\) expected at most 1 argument\(s\) but received 2 argument\(s\)"
        ):
            torch.ops.aten.relu(1, 2)

    def test_passing_too_few_args(self):
        with self.assertRaisesRegex(
            RuntimeError,
            r"aten::relu\(\) is missing value for argument 'self'."
        ):
            torch.ops.aten.relu()

    def test_passing_one_positional_but_not_the_second(self):
        with self.assertRaisesRegex(
            RuntimeError,
            r"aten::transpose\(\) is missing value for argument 'dim0'."
        ):
            torch.ops.aten.transpose(torch.ones(5, 5))

    def test_passing_an_argument_both_as_positional_and_kwarg(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "Argument 'self' specified both as positional and keyword argument"
        ):
            torch.ops._test.leaky_relu(torch.ones(5), self=torch.ones(5))

    def test_passing_unknown_kwargs(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "Unknown keyword argument 'foo' for operator '_test::leaky_relu'"
        ):
            torch.ops._test.leaky_relu(torch.ones(5), foo=torch.ones(5))

    def test_passing_and_returning_lists(self):
        # Replace with actual test once we support lists.
        a, b = torch.rand(5), torch.rand(5)
        output = torch.ops._test.cat([a, b])
        output_ref = torch.cat([a, b])
        self.assertEqual(output, output_ref)

    def test_calling_scripted_custom_op(self):
        @torch.jit.script
        def func(x):
            return torch.ops.aten.relu(x)
        input = torch.ones(5, 5)
        self.assertEqual(func(input), input.relu())

    def test_calling_traced_custom_op(self):
        input = torch.ones(5, 5)
        func = torch.jit.trace(torch.ops.aten.relu, [input])
        self.assertEqual(func(input), input.relu())

    @unittest.skip("Need to figure out default dtype differences between fbcode and oss")
    def test_script_graph_for_custom_ops_matches_traced_graph(self):
        input = torch.ones(5, 5)
        trace = torch.jit.trace(torch.ops.aten.relu, [input])
        self.assertExpectedInline(canonical(trace.graph), '''\
graph(%0 : Float(5, 5)):
  %1 : Float(5, 5) = aten::relu(%0)
  return (%1)
''')

    def test_script_graph_contains_custom_op(self):
        @torch.jit.script
        def func(x):
            return torch.ops.aten.relu(x)
        self.assertExpectedInline(canonical(func.graph), '''\
graph(%x.1 : Tensor):
  %1 : Tensor = aten::relu(%x.1)
  return (%1)
''')

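    # `_test::get_first` is expected to return the first element of the first
    # inner list.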
    def test_generic_list(self):
        self.assertEqual(torch.ops._test.get_first([['hello']]), 'hello')