[jit] move torchbind tests to separate file (#37473)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37473
Test Plan: Imported from OSS
Differential Revision: D21297541
Pulled By: suo
fbshipit-source-id: 65c48094b1f26fbbf251021957257ce04279922b

This commit is contained in:
parent 7d7d73655d
commit 2efa7e04c2
test/jit/test_torchbind.py (new file, 241 lines)

@@ -0,0 +1,241 @@
import io
import os
import sys
import torch
from typing import Optional

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing import FileCheck

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

class TestTorchbind(JitTestCase):
    @skipIfRocm
    def test_torchbind(self):
        def test_equality(f, cmp_key):
            obj1 = f()
            obj2 = torch.jit.script(f)()
            return (cmp_key(obj1), cmp_key(obj2))

        def f():
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment(1)
            return val
        test_equality(f, lambda x: x)

        with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment('foo')

        def f():
            ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            return ss.pop()
        test_equality(f, lambda x: x)

        def f():
            ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
            ss1.push(ss2.pop())
            return ss1.pop() + ss2.pop()
        test_equality(f, lambda x: x)

    @skipIfRocm
    def test_torchbind_take_as_arg(self):
        global StackString  # see [local resolution in python]
        StackString = torch.classes._TorchScriptTesting._StackString

        def foo(stackstring):
            # type: (StackString)
            stackstring.push("lel")
            return stackstring

        script_input = torch.classes._TorchScriptTesting._StackString([])
        scripted = torch.jit.script(foo)
        script_output = scripted(script_input)
        self.assertEqual(script_output.pop(), "lel")

    @skipIfRocm
    def test_torchbind_return_instance(self):
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            return ss

        scripted = torch.jit.script(foo)
        # Ensure we are creating the object and calling __init__
        # rather than calling the __init__wrapper nonsense
        fc = FileCheck().check('prim::CreateObject()')\
            .check('prim::CallMethod[name="__init__"]')
        fc.run(str(scripted.graph))
        out = scripted()
        self.assertEqual(out.pop(), "mom")
        self.assertEqual(out.pop(), "hi")

    @skipIfRocm
    def test_torchbind_return_instance_from_method(self):
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            clone = ss.clone()
            ss.pop()
            return ss, clone

        scripted = torch.jit.script(foo)
        out = scripted()
        self.assertEqual(out[0].pop(), "hi")
        self.assertEqual(out[1].pop(), "mom")
        self.assertEqual(out[1].pop(), "hi")

    @skipIfRocm
    def test_torchbind_take_instance_as_method_arg(self):
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
            ss.merge(ss2)
            return ss

        scripted = torch.jit.script(foo)
        out = scripted()
        self.assertEqual(out.pop(), "hi")
        self.assertEqual(out.pop(), "mom")

    @skipIfRocm
    def test_torchbind_return_tuple(self):
        def f():
            val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
            return val.return_a_tuple()

        scripted = torch.jit.script(f)
        tup = scripted()
        self.assertEqual(tup, (1337.0, 123))

    @skipIfRocm
    def test_torchbind_save_load(self):
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
            ss.merge(ss2)
            return ss

        scripted = torch.jit.script(foo)
        self.getExportImportCopy(scripted)

    @skipIfRocm
    def test_torchbind_lambda_method(self):
        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            return ss.top()

        scripted = torch.jit.script(foo)
        self.assertEqual(scripted(), "mom")

    @skipIfRocm
    def test_torchbind_class_attribute(self):
        class FooBar1234(torch.nn.Module):
            def __init__(self):
                super(FooBar1234, self).__init__()
                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

            def forward(self):
                return self.f.top()

        inst = FooBar1234()
        scripted = torch.jit.script(inst)
        eic = self.getExportImportCopy(scripted)
        assert eic() == "deserialized"
        for expected in ["deserialized", "was", "i"]:
            assert eic.f.pop() == expected

    @skipIfRocm
    def test_torchbind_getstate(self):
        class FooBar4321(torch.nn.Module):
            def __init__(self):
                super(FooBar4321, self).__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        scripted = torch.jit.script(inst)
        eic = self.getExportImportCopy(scripted)
        # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
        # return {1, 3, 3, 7}. I tried to make this actually depend on the
        # values at instantiation in the test with some transformation, but
        # because it seems we serialize/deserialize multiple times, that
        # transformation isn't as you would expect it to be.
        assert eic() == 7
        for expected in [7, 3, 3, 1]:
            assert eic.f.pop() == expected

    @skipIfRocm
    def test_torchbind_tracing(self):
        class TryTracing(torch.nn.Module):
            def __init__(self):
                super(TryTracing, self).__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return torch.ops._TorchScriptTesting.take_an_instance(self.f)

        traced = torch.jit.trace(TryTracing(), ())
        self.assertEqual(torch.zeros(4, 4), traced())

    @skipIfRocm
    def test_torchbind_tracing_nested(self):
        class TryTracingNest(torch.nn.Module):
            def __init__(self):
                super(TryTracingNest, self).__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        class TryTracing123(torch.nn.Module):
            def __init__(self):
                super(TryTracing123, self).__init__()
                self.nest = TryTracingNest()

            def forward(self):
                return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)

        traced = torch.jit.trace(TryTracing123(), ())
        self.assertEqual(torch.zeros(4, 4), traced())

    @skipIfRocm
    def test_torchbind_pickle_serialization(self):
        nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
        b = io.BytesIO()
        torch.save(nt, b)
        b.seek(0)
        nt_loaded = torch.load(b)
        for exp in [7, 3, 3, 1]:
            self.assertEqual(nt_loaded.pop(), exp)

    @skipIfRocm
    def test_torchbind_instantiate_missing_class(self):
        with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class \'foo.IDontExist\', but it does not exist!'):
            torch.classes.foo.IDontExist(3, 4, 5)

    @skipIfRocm
    def test_torchbind_optional_explicit_attr(self):
        class TorchBindOptionalExplicitAttr(torch.nn.Module):
            foo : Optional[torch.classes._TorchScriptTesting._StackString]

            def __init__(self):
                super().__init__()
                self.foo = torch.classes._TorchScriptTesting._StackString(["test"])

            def forward(self) -> str:
                foo_obj = self.foo
                if foo_obj is not None:
                    return foo_obj.pop()
                else:
                    return '<None>'

        mod = TorchBindOptionalExplicitAttr()
        scripted = torch.jit.script(mod)

[The remaining lines of the 241-line file are not shown in this diff view.]
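Every test above drives small C++ classes (_Foo, _StackString, _PickleTester) that are exposed under torch.classes._TorchScriptTesting via TorchBind; the binding library that registers them is presumably loaded by the test harness before these tests run. Below is a rough, self-contained sketch of the usage pattern being exercised, not part of this commit; the library path is hypothetical and depends on the build layout.

# Rough sketch (not from this commit): load the custom-class test library, then
# use one of its TorchBind classes from eager Python and from TorchScript.
import torch

torch.classes.load_library("build/lib/libtorchbind_test.so")  # hypothetical path

def use_stack() -> str:
    # _StackString is one of the C++ test classes the suite above relies on.
    ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
    ss.push("lel")
    return ss.pop()  # pop returns the most recently pushed string

scripted = torch.jit.script(use_stack)  # TorchBind classes are scriptable as-is
assert use_stack() == scripted() == "lel"

The same objects can also be passed into and returned from scripted functions, which is what test_torchbind_take_as_arg and test_torchbind_return_instance check above.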
test/test_jit.py (233 lines changed)

@@ -23,6 +23,7 @@ from jit.test_freezing import TestFreezing  # noqa: F401
 from jit.test_save_load import TestSaveLoad  # noqa: F401
 from jit.test_python_ir import TestPythonIr  # noqa: F401
 from jit.test_functional_blocks import TestFunctionalBlocks  # noqa: F401
+from jit.test_torchbind import TestTorchbind  # noqa: F401

 # Torch
 from torch import Tensor
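The single import added above is what keeps the moved tests reachable from the old entry point: unittest's loader collects every TestCase subclass visible at module scope, so re-exporting TestTorchbind from test/test_jit.py preserves the usual invocation, for example python test/test_jit.py TestTorchbind.test_torchbind, exactly as the __main__ guard in the new file instructs. A minimal sketch of that discovery pattern, with illustrative names that are not part of this commit:

# umbrella.py (hypothetical): mirrors how an umbrella test file picks up imported tests.
# In the real change the re-export is: from jit.test_torchbind import TestTorchbind  # noqa: F401
import unittest

class TestMoved(unittest.TestCase):  # imagine this class was imported from another file
    def test_addition(self):
        self.assertEqual(1 + 1, 2)

if __name__ == "__main__":
    unittest.main()  # the loader discovers TestMoved because it is a module-level name here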
@@ -4579,114 +4580,6 @@ def foo(x):
         self.assertEqual(7, w(3))
         self.assertFalse("training" in w.state_dict())

    [Removed here: the @skipIfRocm tests test_torchbind, test_torchbind_take_as_arg,
    test_torchbind_return_instance, test_torchbind_return_instance_from_method,
    test_torchbind_take_instance_as_method_arg, test_torchbind_return_tuple and
    test_torchbind_save_load, identical to the versions added in test/jit/test_torchbind.py above.]

     def test_class_as_attribute(self):
         @torch.jit.script
         class Foo321(object):
@@ -4706,124 +4599,6 @@ def foo(x):
         x = torch.rand(3, 4)
         self.assertEqual(scripted(x), eic(x))

    [Removed here: the @skipIfRocm tests test_torchbind_lambda_method,
    test_torchbind_class_attribute, test_torchbind_getstate, test_torchbind_tracing,
    test_torchbind_tracing_nested, test_torchbind_pickle_serialization,
    test_torchbind_instantiate_missing_class and test_torchbind_optional_explicit_attr,
    identical to the versions added in test/jit/test_torchbind.py above, followed by:]

-    @skipIfRocm
-    def test_torchbind_str(self):
-        foo = torch.classes._TorchScriptTesting._StackString(["foo", "bar", "baz"])
-        self.assertEqual(str(foo), "[foo, bar, baz]")
-
     def test_module_str(self):
         class Foo(torch.nn.Module):
             def forward(self, x):
@@ -4832,12 +4607,6 @@ def foo(x):
         f = torch.jit.script(Foo())
         self.assertEqual('ScriptObject', str(f._c))

-    @skipIfRocm
-    def test_torchbind_magic_unimplemented(self):
-        foo = torch.classes._TorchScriptTesting._StackString(["foo", "bar", "baz"])
-        with self.assertRaises(NotImplementedError):
-            foo[3]
-
     def _test_lower_graph_impl(self, model, data):
         model.qconfig = torch.quantization.default_qconfig
         model = torch.quantization.prepare(model)
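Several of the relocated tests (test_torchbind_save_load, test_torchbind_class_attribute, test_torchbind_getstate) revolve around a scripted module surviving a serialization round trip through the getExportImportCopy helper. Below is a simplified sketch of such a round trip using the public torch.jit.save and torch.jit.load API; that this is roughly what the helper does internally is an assumption, and a plain module is used here instead of one holding a TorchBind object.

# Simplified sketch (not from this commit) of a save/load round trip.
import io
import torch

class Doubler(torch.nn.Module):
    def forward(self, x):
        return 2 * x

scripted = torch.jit.script(Doubler())
buf = io.BytesIO()
torch.jit.save(scripted, buf)   # serialize the ScriptModule into an in-memory buffer
buf.seek(0)
reloaded = torch.jit.load(buf)  # deserialize it back
x = torch.rand(3)
assert torch.equal(scripted(x), reloaded(x))

For modules that hold TorchBind attributes, the same round trip goes through the class's registered pickling hooks, which is what test_torchbind_getstate and test_torchbind_pickle_serialization exercise.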