[jit] move torchbind tests to separate file (#37473)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37473

Test Plan: Imported from OSS

Differential Revision: D21297541

Pulled By: suo

fbshipit-source-id: 65c48094b1f26fbbf251021957257ce04279922b
This commit is contained in:
Michael Suo 2020-05-13 17:34:34 -07:00 committed by Facebook GitHub Bot
parent 7d7d73655d
commit 2efa7e04c2
2 changed files with 242 additions and 232 deletions

241
test/jit/test_torchbind.py Normal file
View File

@ -0,0 +1,241 @@
import io
import os
import sys
import torch
from typing import Optional
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing import FileCheck
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestTorchbind(JitTestCase):
@skipIfRocm
def test_torchbind(self):
def test_equality(f, cmp_key):
obj1 = f()
obj2 = torch.jit.script(f)()
return (cmp_key(obj1), cmp_key(obj2))
def f():
val = torch.classes._TorchScriptTesting._Foo(5, 3)
val.increment(1)
return val
test_equality(f, lambda x: x)
with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
val = torch.classes._TorchScriptTesting._Foo(5, 3)
val.increment('foo')
def f():
ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
return ss.pop()
test_equality(f, lambda x: x)
def f():
ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
ss1.push(ss2.pop())
return ss1.pop() + ss2.pop()
test_equality(f, lambda x: x)
@skipIfRocm
def test_torchbind_take_as_arg(self):
global StackString # see [local resolution in python]
StackString = torch.classes._TorchScriptTesting._StackString
def foo(stackstring):
# type: (StackString)
stackstring.push("lel")
return stackstring
script_input = torch.classes._TorchScriptTesting._StackString([])
scripted = torch.jit.script(foo)
script_output = scripted(script_input)
self.assertEqual(script_output.pop(), "lel")
@skipIfRocm
def test_torchbind_return_instance(self):
def foo():
ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
return ss
scripted = torch.jit.script(foo)
# Ensure we are creating the object and calling __init__
# rather than calling the __init__wrapper nonsense
fc = FileCheck().check('prim::CreateObject()')\
.check('prim::CallMethod[name="__init__"]')
fc.run(str(scripted.graph))
out = scripted()
self.assertEqual(out.pop(), "mom")
self.assertEqual(out.pop(), "hi")
@skipIfRocm
def test_torchbind_return_instance_from_method(self):
def foo():
ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
clone = ss.clone()
ss.pop()
return ss, clone
scripted = torch.jit.script(foo)
out = scripted()
self.assertEqual(out[0].pop(), "hi")
self.assertEqual(out[1].pop(), "mom")
self.assertEqual(out[1].pop(), "hi")
@skipIfRocm
def test_torchbind_take_instance_as_method_arg(self):
def foo():
ss = torch.classes._TorchScriptTesting._StackString(["mom"])
ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
ss.merge(ss2)
return ss
scripted = torch.jit.script(foo)
out = scripted()
self.assertEqual(out.pop(), "hi")
self.assertEqual(out.pop(), "mom")
@skipIfRocm
def test_torchbind_return_tuple(self):
def f():
val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
return val.return_a_tuple()
scripted = torch.jit.script(f)
tup = scripted()
self.assertEqual(tup, (1337.0, 123))
@skipIfRocm
def test_torchbind_save_load(self):
def foo():
ss = torch.classes._TorchScriptTesting._StackString(["mom"])
ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
ss.merge(ss2)
return ss
scripted = torch.jit.script(foo)
self.getExportImportCopy(scripted)
@skipIfRocm
def test_torchbind_lambda_method(self):
def foo():
ss = torch.classes._TorchScriptTesting._StackString(["mom"])
return ss.top()
scripted = torch.jit.script(foo)
self.assertEqual(scripted(), "mom")
@skipIfRocm
def test_torchbind_class_attribute(self):
class FooBar1234(torch.nn.Module):
def __init__(self):
super(FooBar1234, self).__init__()
self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
def forward(self):
return self.f.top()
inst = FooBar1234()
scripted = torch.jit.script(inst)
eic = self.getExportImportCopy(scripted)
assert eic() == "deserialized"
for expected in ["deserialized", "was", "i"]:
assert eic.f.pop() == expected
@skipIfRocm
def test_torchbind_getstate(self):
class FooBar4321(torch.nn.Module):
def __init__(self):
super(FooBar4321, self).__init__()
self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
def forward(self):
return self.f.top()
inst = FooBar4321()
scripted = torch.jit.script(inst)
eic = self.getExportImportCopy(scripted)
# NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
# return {1, 3, 3, 7}. I tried to make this actually depend on the
# values at instantiation in the test with some transformation, but
# because it seems we serialize/deserialize multiple times, that
# transformation isn't as you would it expect it to be.
assert eic() == 7
for expected in [7, 3, 3, 1]:
assert eic.f.pop() == expected
@skipIfRocm
def test_torchbind_tracing(self):
class TryTracing(torch.nn.Module):
def __init__(self):
super(TryTracing, self).__init__()
self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
def forward(self):
return torch.ops._TorchScriptTesting.take_an_instance(self.f)
traced = torch.jit.trace(TryTracing(), ())
self.assertEqual(torch.zeros(4, 4), traced())
@skipIfRocm
def test_torchbind_tracing_nested(self):
class TryTracingNest(torch.nn.Module):
def __init__(self):
super(TryTracingNest, self).__init__()
self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
class TryTracing123(torch.nn.Module):
def __init__(self):
super(TryTracing123, self).__init__()
self.nest = TryTracingNest()
def forward(self):
return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)
traced = torch.jit.trace(TryTracing123(), ())
self.assertEqual(torch.zeros(4, 4), traced())
@skipIfRocm
def test_torchbind_pickle_serialization(self):
nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
b = io.BytesIO()
torch.save(nt, b)
b.seek(0)
nt_loaded = torch.load(b)
for exp in [7, 3, 3, 1]:
self.assertEqual(nt_loaded.pop(), exp)
@skipIfRocm
def test_torchbind_instantiate_missing_class(self):
with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class \'foo.IDontExist\', but it does not exist!'):
torch.classes.foo.IDontExist(3, 4, 5)
@skipIfRocm
def test_torchbind_optional_explicit_attr(self):
class TorchBindOptionalExplicitAttr(torch.nn.Module):
foo : Optional[torch.classes._TorchScriptTesting._StackString]
def __init__(self):
super().__init__()
self.foo = torch.classes._TorchScriptTesting._StackString(["test"])
def forward(self) -> str:
foo_obj = self.foo
if foo_obj is not None:
return foo_obj.pop()
else:
return '<None>'
mod = TorchBindOptionalExplicitAttr()
scripted = torch.jit.script(mod)

View File

@ -23,6 +23,7 @@ from jit.test_freezing import TestFreezing # noqa: F401
from jit.test_save_load import TestSaveLoad # noqa: F401
from jit.test_python_ir import TestPythonIr # noqa: F401
from jit.test_functional_blocks import TestFunctionalBlocks # noqa: F401
from jit.test_torchbind import TestTorchbind # noqa: F401
# Torch
from torch import Tensor
@ -4579,114 +4580,6 @@ def foo(x):
self.assertEqual(7, w(3))
self.assertFalse("training" in w.state_dict())
@skipIfRocm
def test_torchbind(self):
def test_equality(f, cmp_key):
obj1 = f()
obj2 = torch.jit.script(f)()
return (cmp_key(obj1), cmp_key(obj2))
def f():
val = torch.classes._TorchScriptTesting._Foo(5, 3)
val.increment(1)
return val
test_equality(f, lambda x: x)
with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
val = torch.classes._TorchScriptTesting._Foo(5, 3)
val.increment('foo')
def f():
ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
return ss.pop()
test_equality(f, lambda x: x)
def f():
ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
ss1.push(ss2.pop())
return ss1.pop() + ss2.pop()
test_equality(f, lambda x: x)
@skipIfRocm
def test_torchbind_take_as_arg(self):
global StackString # see [local resolution in python]
StackString = torch.classes._TorchScriptTesting._StackString
def foo(stackstring):
# type: (StackString)
stackstring.push("lel")
return stackstring
script_input = torch.classes._TorchScriptTesting._StackString([])
scripted = torch.jit.script(foo)
script_output = scripted(script_input)
self.assertEqual(script_output.pop(), "lel")
@skipIfRocm
def test_torchbind_return_instance(self):
def foo():
ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
return ss
scripted = torch.jit.script(foo)
# Ensure we are creating the object and calling __init__
# rather than calling the __init__wrapper nonsense
fc = FileCheck().check('prim::CreateObject()')\
.check('prim::CallMethod[name="__init__"]')
fc.run(str(scripted.graph))
out = scripted()
self.assertEqual(out.pop(), "mom")
self.assertEqual(out.pop(), "hi")
@skipIfRocm
def test_torchbind_return_instance_from_method(self):
def foo():
ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
clone = ss.clone()
ss.pop()
return ss, clone
scripted = torch.jit.script(foo)
out = scripted()
self.assertEqual(out[0].pop(), "hi")
self.assertEqual(out[1].pop(), "mom")
self.assertEqual(out[1].pop(), "hi")
@skipIfRocm
def test_torchbind_take_instance_as_method_arg(self):
def foo():
ss = torch.classes._TorchScriptTesting._StackString(["mom"])
ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
ss.merge(ss2)
return ss
scripted = torch.jit.script(foo)
out = scripted()
self.assertEqual(out.pop(), "hi")
self.assertEqual(out.pop(), "mom")
@skipIfRocm
def test_torchbind_return_tuple(self):
def f():
val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
return val.return_a_tuple()
scripted = torch.jit.script(f)
tup = scripted()
self.assertEqual(tup, (1337.0, 123))
@skipIfRocm
def test_torchbind_save_load(self):
def foo():
ss = torch.classes._TorchScriptTesting._StackString(["mom"])
ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
ss.merge(ss2)
return ss
scripted = torch.jit.script(foo)
self.getExportImportCopy(scripted)
def test_class_as_attribute(self):
@torch.jit.script
class Foo321(object):
@ -4706,124 +4599,6 @@ def foo(x):
x = torch.rand(3, 4)
self.assertEqual(scripted(x), eic(x))
@skipIfRocm
def test_torchbind_lambda_method(self):
def foo():
ss = torch.classes._TorchScriptTesting._StackString(["mom"])
return ss.top()
scripted = torch.jit.script(foo)
self.assertEqual(scripted(), "mom")
@skipIfRocm
def test_torchbind_class_attribute(self):
class FooBar1234(torch.nn.Module):
def __init__(self):
super(FooBar1234, self).__init__()
self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
def forward(self):
return self.f.top()
inst = FooBar1234()
scripted = torch.jit.script(inst)
eic = self.getExportImportCopy(scripted)
assert eic() == "deserialized"
for expected in ["deserialized", "was", "i"]:
assert eic.f.pop() == expected
@skipIfRocm
def test_torchbind_getstate(self):
class FooBar4321(torch.nn.Module):
def __init__(self):
super(FooBar4321, self).__init__()
self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
def forward(self):
return self.f.top()
inst = FooBar4321()
scripted = torch.jit.script(inst)
eic = self.getExportImportCopy(scripted)
# NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
# return {1, 3, 3, 7}. I tried to make this actually depend on the
# values at instantiation in the test with some transformation, but
# because it seems we serialize/deserialize multiple times, that
# transformation isn't as you would it expect it to be.
assert eic() == 7
for expected in [7, 3, 3, 1]:
assert eic.f.pop() == expected
@skipIfRocm
def test_torchbind_tracing(self):
class TryTracing(torch.nn.Module):
def __init__(self):
super(TryTracing, self).__init__()
self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
def forward(self):
return torch.ops._TorchScriptTesting.take_an_instance(self.f)
traced = torch.jit.trace(TryTracing(), ())
self.assertEqual(torch.zeros(4, 4), traced())
@skipIfRocm
def test_torchbind_tracing_nested(self):
class TryTracingNest(torch.nn.Module):
def __init__(self):
super(TryTracingNest, self).__init__()
self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
class TryTracing123(torch.nn.Module):
def __init__(self):
super(TryTracing123, self).__init__()
self.nest = TryTracingNest()
def forward(self):
return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)
traced = torch.jit.trace(TryTracing123(), ())
self.assertEqual(torch.zeros(4, 4), traced())
@skipIfRocm
def test_torchbind_pickle_serialization(self):
nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
b = io.BytesIO()
torch.save(nt, b)
b.seek(0)
nt_loaded = torch.load(b)
for exp in [7, 3, 3, 1]:
self.assertEqual(nt_loaded.pop(), exp)
@skipIfRocm
def test_torchbind_instantiate_missing_class(self):
with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class \'foo.IDontExist\', but it does not exist!'):
torch.classes.foo.IDontExist(3, 4, 5)
@skipIfRocm
def test_torchbind_optional_explicit_attr(self):
class TorchBindOptionalExplicitAttr(torch.nn.Module):
foo : Optional[torch.classes._TorchScriptTesting._StackString]
def __init__(self):
super().__init__()
self.foo = torch.classes._TorchScriptTesting._StackString(["test"])
def forward(self) -> str:
foo_obj = self.foo
if foo_obj is not None:
return foo_obj.pop()
else:
return '<None>'
mod = TorchBindOptionalExplicitAttr()
scripted = torch.jit.script(mod)
@skipIfRocm
def test_torchbind_str(self):
foo = torch.classes._TorchScriptTesting._StackString(["foo", "bar", "baz"])
self.assertEqual(str(foo), "[foo, bar, baz]")
def test_module_str(self):
class Foo(torch.nn.Module):
def forward(self, x):
@ -4832,12 +4607,6 @@ def foo(x):
f = torch.jit.script(Foo())
self.assertEqual('ScriptObject', str(f._c))
@skipIfRocm
def test_torchbind_magic_unimplemented(self):
foo = torch.classes._TorchScriptTesting._StackString(["foo", "bar", "baz"])
with self.assertRaises(NotImplementedError):
foo[3]
def _test_lower_graph_impl(self, model, data):
model.qconfig = torch.quantization.default_qconfig
model = torch.quantization.prepare(model)