mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: note: breaking the original diff D55225818 into 3 parts (top-level renaming, higher-order-op subgraphs, constant input de/serialization) because of its size. Stacked PR to restore original names to placeholder nodes, replacing the default names arg0_1, arg1_1, ... This PR supports constant argument placeholder (e.g. forward(self, x, y=1)) names and de/serialization, by adding a name field for ConstantArguments in the graph signature, and ConstantInputSpec in the input specs for serialization. Test Plan: verification checks on placeholder names for all export() calls, unit test in test/export/test_export.py Differential Revision: D55506949 Pull Request resolved: https://github.com/pytorch/pytorch/pull/123590 Approved by: https://github.com/angelayi, https://github.com/zhxchen17
628 lines
25 KiB
Python
628 lines
25 KiB
Python
# Owner(s): ["oncall: export"]
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
import torch.utils._pytree as pytree
|
|
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
|
|
from torch._library.fake_class_registry import FakeScriptObject
|
|
from torch.export import export
|
|
from torch.export._trace import _export
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
from torch.testing._internal.common_utils import (
|
|
find_library_location,
|
|
instantiate_parametrized_tests,
|
|
IS_FBCODE,
|
|
IS_MACOS,
|
|
IS_SANDCASTLE,
|
|
IS_WINDOWS,
|
|
parametrize,
|
|
run_tests,
|
|
skipIfTorchDynamo,
|
|
TestCase,
|
|
)
|
|
from torch.testing._internal.torchbind_impls import register_fake_operators
|
|
|
|
|
|
def load_torchbind_test_lib():
|
|
if IS_SANDCASTLE or IS_FBCODE:
|
|
torch.ops.load_library("//caffe2/test/cpp/jit:test_custom_class_registrations")
|
|
elif IS_MACOS:
|
|
raise unittest.SkipTest("non-portable load_library call used in test")
|
|
else:
|
|
lib_file_path = find_library_location("libtorchbind_test.so")
|
|
if IS_WINDOWS:
|
|
lib_file_path = find_library_location("torchbind_test.dll")
|
|
torch.ops.load_library(str(lib_file_path))
|
|
|
|
register_fake_operators()
|
|
|
|
|
|
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
|
|
class TestExportTorchbind(TestCase):
|
|
def setUp(self):
|
|
load_torchbind_test_lib()
|
|
|
|
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
|
|
class FakeFoo:
|
|
def __init__(self, x: int, y: int):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
@classmethod
|
|
def from_real(cls, foo):
|
|
(x, y), _ = foo.__getstate__()
|
|
return cls(x, y)
|
|
|
|
def add_tensor(self, z):
|
|
return (self.x + self.y) * z
|
|
|
|
test = self
|
|
test.tq_push_counter = 0
|
|
test.tq_pop_counter = 0
|
|
test.tq_size_counter = 0
|
|
|
|
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
|
|
class FakeTensorQueue:
|
|
def __init__(self, q):
|
|
self.queue = q
|
|
|
|
@classmethod
|
|
def from_real(cls, real_tq):
|
|
ctx = torch.library.get_ctx()
|
|
fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
|
|
return cls(fake_queue)
|
|
|
|
def push(self, x):
|
|
test.tq_push_counter += 1
|
|
self.queue.append(x)
|
|
|
|
def pop(self):
|
|
test.tq_pop_counter += 1
|
|
return self.queue.pop(0)
|
|
|
|
def size(self):
|
|
test.tq_size_counter += 1
|
|
return len(self.queue)
|
|
|
|
def tearDown(self):
|
|
torch._library.fake_class_registry.deregister_fake_class(
|
|
"_TorchScriptTesting::_Foo"
|
|
)
|
|
torch._library.fake_class_registry.deregister_fake_class(
|
|
"_TorchScriptTesting::_TensorQueue"
|
|
)
|
|
|
|
def _assertEqualSkipScriptObject(self, exp, actual):
|
|
flat_exp = pytree.tree_leaves(exp)
|
|
flat_actual = pytree.tree_leaves(actual)
|
|
self.assertEqual(len(flat_exp), len(flat_actual))
|
|
for a, b in zip(flat_exp, flat_actual):
|
|
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
|
|
continue
|
|
self.assertEqual(a, b)
|
|
|
|
def _test_export_same_as_eager(
|
|
self, f, args, kwargs=None, strict=True, pre_dispatch=False
|
|
):
|
|
kwargs = kwargs or {}
|
|
|
|
def export_wrapper(f, args, kwargs, strcit, pre_dispatch):
|
|
with enable_torchbind_tracing():
|
|
if pre_dispatch:
|
|
exported_program = _export(
|
|
f, args, kwargs, strict=strict, pre_dispatch=True
|
|
)
|
|
else:
|
|
exported_program = export(f, args, kwargs, strict=strict)
|
|
return exported_program
|
|
|
|
exported_program = export_wrapper(f, args, kwargs, strict, pre_dispatch)
|
|
reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)}
|
|
unlifted = exported_program.module()
|
|
exp = f(*args, **kwargs)
|
|
self.assertEqual(unlifted(*args, **kwargs), exp)
|
|
self.assertEqual(
|
|
unlifted(*args, **reversed_kwargs),
|
|
exp,
|
|
)
|
|
|
|
# check re-tracing
|
|
retraced_ep = export_wrapper(unlifted, args, kwargs, strict, pre_dispatch)
|
|
self.assertEqual(retraced_ep.module()(*args, **kwargs), exp)
|
|
return exported_program
|
|
|
|
@parametrize("pre_dispatch", [True, False])
|
|
def test_none(self, pre_dispatch):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
|
|
|
def forward(self, x, n):
|
|
return x + self.attr.add_tensor(x)
|
|
|
|
ep = self._test_export_same_as_eager(
|
|
MyModule(),
|
|
(torch.ones(2, 3), None),
|
|
strict=False,
|
|
pre_dispatch=pre_dispatch,
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.module().code.strip(),
|
|
"""\
|
|
def forward(self, arg_0, arg_1):
|
|
x, n, = fx_pytree.tree_flatten_spec(([arg_0, arg_1], {}), self._in_spec)
|
|
attr = self.attr
|
|
call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None
|
|
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
|
|
return pytree.tree_unflatten((add,), self._out_spec)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.graph_module.code.strip(),
|
|
"""\
|
|
def forward(self, obj_attr, x, n):
|
|
call_torchbind = torch.ops.higher_order.call_torchbind(obj_attr, 'add_tensor', x); obj_attr = None
|
|
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
|
|
return (add,)""",
|
|
)
|
|
|
|
@parametrize("pre_dispatch", [True, False])
|
|
def test_attribute(self, pre_dispatch):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
|
|
|
def forward(self, x):
|
|
return x + self.attr.add_tensor(x)
|
|
|
|
ep = self._test_export_same_as_eager(
|
|
MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.module().code.strip(),
|
|
"""\
|
|
def forward(self, arg_0):
|
|
x, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
|
|
attr = self.attr
|
|
call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None
|
|
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
|
|
return pytree.tree_unflatten((add,), self._out_spec)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.graph_module.code.strip(),
|
|
"""\
|
|
def forward(self, obj_attr, x):
|
|
call_torchbind = torch.ops.higher_order.call_torchbind(obj_attr, 'add_tensor', x); obj_attr = None
|
|
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
|
|
return (add,)""",
|
|
)
|
|
|
|
@parametrize("pre_dispatch", [True, False])
|
|
def test_attribute_as_custom_op_argument(self, pre_dispatch):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
|
|
|
def forward(self, x):
|
|
return x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
|
|
|
|
ep = self._test_export_same_as_eager(
|
|
MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.module().code.strip(),
|
|
"""\
|
|
def forward(self, arg_0):
|
|
x, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
|
|
attr = self.attr
|
|
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x); attr = None
|
|
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
|
|
return pytree.tree_unflatten((add,), self._out_spec)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.graph_module.code.strip(),
|
|
"""\
|
|
def forward(self, token, obj_attr, x):
|
|
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, x); token = obj_attr = None
|
|
getitem = with_effects[0]
|
|
getitem_1 = with_effects[1]; with_effects = None
|
|
add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None
|
|
return (getitem, add)""", # noqa: B950
|
|
)
|
|
|
|
@parametrize("pre_dispatch", [True, False])
|
|
def test_input(self, pre_dispatch):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, x, cc):
|
|
return x + cc.add_tensor(x)
|
|
|
|
cc = torch.classes._TorchScriptTesting._Foo(10, 20)
|
|
ep = self._test_export_same_as_eager(
|
|
MyModule(), (torch.ones(2, 3), cc), strict=False, pre_dispatch=pre_dispatch
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.module().code.strip(),
|
|
"""\
|
|
def forward(self, arg_0, arg_1):
|
|
x, cc, = fx_pytree.tree_flatten_spec(([arg_0, arg_1], {}), self._in_spec)
|
|
call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x); cc = None
|
|
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
|
|
return pytree.tree_unflatten((add,), self._out_spec)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.graph_module.code.strip(),
|
|
"""\
|
|
def forward(self, x, cc):
|
|
call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x); cc = None
|
|
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
|
|
return (add,)""",
|
|
)
|
|
|
|
@parametrize("pre_dispatch", [True, False])
|
|
def test_input_as_custom_op_argument(self, pre_dispatch):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, x, cc):
|
|
return x + torch.ops._TorchScriptTesting.takes_foo(cc, x)
|
|
|
|
cc = torch.classes._TorchScriptTesting._Foo(10, 20)
|
|
ep = self._test_export_same_as_eager(
|
|
MyModule(), (torch.ones(2, 3), cc), strict=False, pre_dispatch=pre_dispatch
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.module().code.strip(),
|
|
"""\
|
|
def forward(self, arg_0, arg_1):
|
|
x, cc, = fx_pytree.tree_flatten_spec(([arg_0, arg_1], {}), self._in_spec)
|
|
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(cc, x); cc = None
|
|
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
|
|
return pytree.tree_unflatten((add,), self._out_spec)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.graph_module.code.strip(),
|
|
"""\
|
|
def forward(self, token, x, cc):
|
|
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, cc, x); token = cc = None
|
|
getitem = with_effects[0]
|
|
getitem_1 = with_effects[1]; with_effects = None
|
|
add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None
|
|
return (getitem, add)""", # noqa: B950
|
|
)
|
|
|
|
@parametrize("pre_dispatch", [True, False])
|
|
def test_unlift_custom_obj(self, pre_dispatch):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
|
|
|
def forward(self, x):
|
|
a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
|
|
b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a)
|
|
return x + b
|
|
|
|
input = torch.ones(2, 3)
|
|
ep = self._test_export_same_as_eager(
|
|
MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.module().code.strip(),
|
|
"""\
|
|
def forward(self, arg_0):
|
|
x, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
|
|
attr = self.attr
|
|
takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
|
|
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None
|
|
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
|
|
return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.graph_module.code.strip(),
|
|
"""\
|
|
def forward(self, token, obj_attr, x):
|
|
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, x); token = None
|
|
getitem = with_effects[0]
|
|
getitem_1 = with_effects[1]; with_effects = None
|
|
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, getitem_1); getitem = obj_attr = getitem_1 = None
|
|
getitem_2 = with_effects_1[0]
|
|
getitem_3 = with_effects_1[1]; with_effects_1 = None
|
|
add = torch.ops.aten.add.Tensor(x, getitem_3); x = getitem_3 = None
|
|
return (getitem_2, add)""", # noqa: B950
|
|
)
|
|
|
|
@parametrize("pre_dispatch", [True, False])
|
|
def test_custom_obj_list_out(self, pre_dispatch):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
|
|
|
def forward(self, x):
|
|
a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
|
|
y = a[0] + a[1] + a[2]
|
|
b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
|
|
return x + b
|
|
|
|
input = torch.ones(2, 3)
|
|
ep = self._test_export_same_as_eager(
|
|
MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.module().code.strip(),
|
|
"""\
|
|
def forward(self, arg_0):
|
|
x, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
|
|
attr = self.attr
|
|
takes_foo_list_return_default = torch.ops._TorchScriptTesting.takes_foo_list_return.default(attr, x)
|
|
getitem_2 = takes_foo_list_return_default[0]
|
|
getitem_3 = takes_foo_list_return_default[1]
|
|
getitem_4 = takes_foo_list_return_default[2]; takes_foo_list_return_default = None
|
|
add = torch.ops.aten.add.Tensor(getitem_2, getitem_3); getitem_2 = getitem_3 = None
|
|
add_1 = torch.ops.aten.add.Tensor(add, getitem_4); add = getitem_4 = None
|
|
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, add_1); attr = add_1 = None
|
|
add_2 = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
|
|
return pytree.tree_unflatten((add_2,), self._out_spec)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.graph_module.code.strip(),
|
|
"""\
|
|
def forward(self, token, obj_attr, x):
|
|
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_list_return.default, obj_attr, x); token = None
|
|
getitem = with_effects[0]
|
|
getitem_1 = with_effects[1]; with_effects = None
|
|
getitem_2 = getitem_1[0]
|
|
getitem_3 = getitem_1[1]
|
|
getitem_4 = getitem_1[2]; getitem_1 = None
|
|
add = torch.ops.aten.add.Tensor(getitem_2, getitem_3); getitem_2 = getitem_3 = None
|
|
add_1 = torch.ops.aten.add.Tensor(add, getitem_4); add = getitem_4 = None
|
|
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add_1); getitem = obj_attr = add_1 = None
|
|
getitem_5 = with_effects_1[0]
|
|
getitem_6 = with_effects_1[1]; with_effects_1 = None
|
|
add_2 = torch.ops.aten.add.Tensor(x, getitem_6); x = getitem_6 = None
|
|
return (getitem_5, add_2)""", # noqa: B950
|
|
)
|
|
|
|
@parametrize("pre_dispatch", [True, False])
|
|
def test_custom_obj_tuple_out(self, pre_dispatch):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
|
|
|
def forward(self, x):
|
|
a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
|
|
y = a[0] + a[1]
|
|
b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
|
|
return x + b
|
|
|
|
input = torch.ones(2, 3)
|
|
ep = self._test_export_same_as_eager(
|
|
MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.module().code.strip(),
|
|
"""\
|
|
def forward(self, arg_0):
|
|
x, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
|
|
attr = self.attr
|
|
takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(attr, x)
|
|
getitem_1 = takes_foo_tuple_return_default[0]
|
|
getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None
|
|
add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
|
|
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, add); attr = add = None
|
|
add_1 = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
|
|
return pytree.tree_unflatten((add_1,), self._out_spec)""",
|
|
)
|
|
self.assertExpectedInline(
|
|
ep.graph_module.code.strip(),
|
|
"""\
|
|
def forward(self, token, obj_attr, x):
|
|
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, obj_attr, x); token = None
|
|
getitem = with_effects[0]
|
|
getitem_1 = with_effects[1]
|
|
getitem_2 = with_effects[2]; with_effects = None
|
|
add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
|
|
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add); getitem = obj_attr = add = None
|
|
getitem_3 = with_effects_1[0]
|
|
getitem_4 = with_effects_1[1]; with_effects_1 = None
|
|
add_1 = torch.ops.aten.add.Tensor(x, getitem_4); x = getitem_4 = None
|
|
return (getitem_3, add_1)""", # noqa: B950
|
|
)
|
|
|
|
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
|
|
def test_make_fx_tensor_queue_methods(self, make_fx_tracing_mode):
|
|
test = self
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 2)
|
|
self.check_tq_is_fake = True
|
|
|
|
def forward(self, tq, x):
|
|
if self.check_tq_is_fake:
|
|
test.assertTrue(isinstance(tq, FakeScriptObject))
|
|
tq.push(x.cos())
|
|
tq.push(x.sin())
|
|
x_cos = tq.pop() + tq.size()
|
|
x_sin = tq.pop() - tq.size()
|
|
return x_sin, x_cos, tq
|
|
|
|
mod = Model()
|
|
tq = torch.classes._TorchScriptTesting._TensorQueue(
|
|
torch.empty(
|
|
0,
|
|
).fill_(-1)
|
|
)
|
|
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
|
|
torch.empty(
|
|
0,
|
|
).fill_(-1)
|
|
)
|
|
x = torch.ones(2, 3)
|
|
gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
|
|
self.assertEqual(self.tq_push_counter, 2)
|
|
self.assertEqual(self.tq_pop_counter, 2)
|
|
self.assertEqual(self.tq_size_counter, 2)
|
|
self.assertEqual(tq.size(), 0)
|
|
self.assertExpectedInline(
|
|
gm.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1):
|
|
cos = torch.ops.aten.cos.default(arg1_1)
|
|
call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'push', cos); cos = None
|
|
sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
|
|
call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'push', sin); sin = None
|
|
call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
|
|
call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
|
|
add = torch.ops.aten.add.Tensor(call_torchbind_2, 1); call_torchbind_2 = None
|
|
call_torchbind_4 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
|
|
call_torchbind_5 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
|
|
sub = torch.ops.aten.sub.Tensor(call_torchbind_4, 0); call_torchbind_4 = None
|
|
return (sub, add, arg0_1)
|
|
""",
|
|
)
|
|
mod.check_tq_is_fake = False
|
|
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
|
|
|
|
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
|
|
def test_make_fx_tensor_queue_methods_fakify_internal_states(
|
|
self, make_fx_tracing_mode
|
|
):
|
|
test = self
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(3, 2)
|
|
self.check_tq_is_fake = True
|
|
self.current_test = test
|
|
|
|
def forward(self, tq, x):
|
|
if self.check_tq_is_fake:
|
|
self.current_test.assertTrue(isinstance(tq, FakeScriptObject))
|
|
x_cos = tq.pop() + tq.size() + x
|
|
x_sin = tq.pop() - tq.size() + x
|
|
return x_sin, x_cos, tq
|
|
|
|
mod = Model()
|
|
tq = torch.classes._TorchScriptTesting._TensorQueue(
|
|
torch.empty(
|
|
0,
|
|
).fill_(-1)
|
|
)
|
|
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
|
|
torch.empty(
|
|
0,
|
|
).fill_(-1)
|
|
)
|
|
for _ in range(2):
|
|
tq.push(torch.ones(2, 3))
|
|
tq1.push(torch.ones(2, 3))
|
|
x = torch.ones(2, 3)
|
|
prev_size = tq.size()
|
|
gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
|
|
self.assertEqual(self.tq_push_counter, 0)
|
|
self.assertEqual(self.tq_pop_counter, 2)
|
|
self.assertEqual(self.tq_size_counter, 2)
|
|
self.assertEqual(tq.size(), prev_size)
|
|
self.assertExpectedInline(
|
|
gm.code.strip("\n"),
|
|
"""\
|
|
def forward(self, arg0_1, arg1_1):
|
|
call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
|
|
call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
|
|
add = torch.ops.aten.add.Tensor(call_torchbind, 1); call_torchbind = None
|
|
add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = None
|
|
call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
|
|
call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
|
|
sub = torch.ops.aten.sub.Tensor(call_torchbind_2, 0); call_torchbind_2 = None
|
|
add_2 = torch.ops.aten.add.Tensor(sub, arg1_1); sub = arg1_1 = None
|
|
return (add_2, add_1, arg0_1)
|
|
""",
|
|
)
|
|
# turn off tq type checking in eager execution
|
|
mod.check_tq_is_fake = False
|
|
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
|
|
self.assertEqual(tq.size(), 0)
|
|
self.assertEqual(tq1.size(), 0)
|
|
|
|
|
|
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
|
|
class TestRegisterFakeClass(TestCase):
|
|
def setUp(self):
|
|
load_torchbind_test_lib()
|
|
|
|
def tearDown(self):
|
|
torch._library.fake_class_registry.global_fake_class_registry.clear()
|
|
|
|
def test_register_fake_class_no_torch_bind_class(self):
|
|
with self.assertRaisesRegex(RuntimeError, "Tried to instantiate class"):
|
|
|
|
@torch._library.register_fake_class("_TorchScriptTesting::NOT_A_VALID_NAME")
|
|
class Invalid:
|
|
pass
|
|
|
|
def test_register_fake_class_no_from_real(self):
|
|
with self.assertRaisesRegex(RuntimeError, "define a classmethod from_real"):
|
|
|
|
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
|
|
class InvalidFakeFoo:
|
|
def __init__(self):
|
|
pass
|
|
|
|
def test_register_fake_class_from_real_not_classmethod(self):
|
|
with self.assertRaisesRegex(RuntimeError, "is not a classmethod"):
|
|
|
|
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
|
|
class FakeFoo:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
def from_real(self, foo_obj):
|
|
x, y = foo_obj.__getstate__()
|
|
return FakeFoo(x, y)
|
|
|
|
def test_register_fake_class_valid(self):
|
|
class FakeFoo:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
@classmethod
|
|
def from_real(cls, foo_obj):
|
|
x, y = foo_obj.__getstate__()
|
|
return cls(x, y)
|
|
|
|
torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)
|
|
|
|
def test_register_fake_class_duplicate_registration(self):
|
|
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
|
|
class FakeFoo:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
@classmethod
|
|
def from_real(cls, foo_obj):
|
|
x, y = foo_obj.__getstate__()
|
|
return cls(x, y)
|
|
|
|
with self.assertWarnsRegex(UserWarning, "already registered"):
|
|
torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)
|
|
|
|
|
|
instantiate_parametrized_tests(TestExportTorchbind)
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|