mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Following the creation of effect tokens (https://github.com/pytorch/pytorch/pull/120296), we now want to add support for these tokens in export, because the calling/returning convention has changed. The inputs are now `(tokens, params, buffers, constants, user_inputs)` and the outputs are `(tokens, buffer_mutations, user_mutations, user_outputs)`. The graph looks something like:

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %attr : [num_users=2] = placeholder[target=attr]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %with_effects : [num_users=2] = call_function[target=torch._higher_order_ops.effects.with_effects](args = (%arg0_1, _TorchScriptTesting.takes_foo.default, %attr, %arg1_1), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 1), kwargs = {})
    %with_effects_1 : [num_users=2] = call_function[target=torch._higher_order_ops.effects.with_effects](args = (%getitem, _TorchScriptTesting.takes_foo.default, %attr, %getitem_1), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects_1, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects_1, 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg1_1, %getitem_3), kwargs = {})
    return (getitem_2, add)
```

During unlifting, we first remove the tokens and `with_effects` calls using the `remove_effect_tokens` pass. (cc @SherlockNoMad on the pass to remove tokens). This ensures that unlifting does not change the calling convention when the module is retraced.
The graph after unlifting looks something like:

```
graph():
    %attr_1 : [num_users=2] = get_attr[target=attr]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %takes_foo_default_1 : [num_users=1] = call_function[target=torch.ops._TorchScriptTesting.takes_foo.default](args = (%attr_1, %arg1_1), kwargs = {})
    %takes_foo_default : [num_users=1] = call_function[target=torch.ops._TorchScriptTesting.takes_foo.default](args = (%attr_1, %takes_foo_default_1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg1_1, %takes_foo_default), kwargs = {})
    return (add,)
```

Serialization support will be added in a follow-up. Note: tokens currently only affect custom ops that take ScriptObjects as inputs; ScriptObject methods are not handled yet.
Differential Revision: [D54639390](https://our.internmc.facebook.com/intern/diff/D54639390)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121424
Approved by: https://github.com/tugsbayasgalan
184 lines
6.2 KiB
Python
184 lines
6.2 KiB
Python
# Owner(s): ["oncall: export"]
|
|
import unittest
|
|
|
|
import torch
|
|
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
|
|
from torch.export import export
|
|
from torch.testing._internal.common_utils import (
|
|
find_library_location,
|
|
IS_FBCODE,
|
|
IS_MACOS,
|
|
IS_SANDCASTLE,
|
|
IS_WINDOWS,
|
|
run_tests,
|
|
skipIfTorchDynamo,
|
|
TestCase,
|
|
)
|
|
|
|
|
|
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestExportTorchbind(TestCase):
    """Export tests for modules that use TorchBind custom class objects.

    Each test builds a small ``nn.Module`` that uses a
    ``torch.classes._TorchScriptTesting._Foo`` object in some way:
    as a module attribute, as a ``forward`` input, or as an argument to a
    ``_TorchScriptTesting`` custom op. It then checks that the exported /
    unlifted program produces the same result as eager execution.
    """

    def setUp(self):
        # Run the base TestCase per-test setup; overriding setUp without
        # delegating would silently skip it.
        super().setUp()
        # The _TorchScriptTesting classes and ops are registered by a
        # separately built test library; load the platform-specific artifact.
        if IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        elif IS_SANDCASTLE or IS_FBCODE:
            torch.ops.load_library(
                "//caffe2/test/cpp/jit:test_custom_class_registrations"
            )
        elif IS_WINDOWS:
            lib_file_path = find_library_location("torchbind_test.dll")
            torch.ops.load_library(str(lib_file_path))
        else:
            lib_file_path = find_library_location("libtorchbind_test.so")
            torch.ops.load_library(str(lib_file_path))

    def _test_export_same_as_eager(self, f, args, kwargs=None, strict=True):
        """Export ``f`` and assert its unlifted module matches eager ``f``.

        Args:
            f: the module/callable to export.
            args: positional example inputs.
            kwargs: keyword example inputs (defaults to none).
            strict: forwarded to ``torch.export.export``.

        The comparison runs twice: once with ``kwargs`` in their given order
        and once with the keyword insertion order reversed.
        """
        kwargs = kwargs or {}
        with enable_torchbind_tracing():
            exported_program = export(f, args, kwargs, strict=strict)
        reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)}
        unlifted = exported_program.module()
        self.assertEqual(unlifted(*args, **kwargs), f(*args, **kwargs))
        self.assertEqual(
            unlifted(*args, **reversed_kwargs),
            f(*args, **reversed_kwargs),
        )

    def _test_export_unlift_reexport(self, m, args):
        """Export ``m``, unlift it, re-export the unlifted module, and compare.

        Asserts that both the unlifted module and the module obtained by
        exporting the unlifted module again (non-strict, under
        ``enable_torchbind_tracing``) return the same result as eager ``m``
        on ``args``.
        """
        with enable_torchbind_tracing():
            ep = torch.export.export(m, args, strict=False)

        unlifted = ep.module()
        self.assertEqual(m(*args), unlifted(*args))

        # Re-exporting checks that the unlifted graph can itself be traced.
        with enable_torchbind_tracing():
            ep2 = torch.export.export(unlifted, args, strict=False)

        self.assertEqual(m(*args), ep2.module()(*args))

    def test_none(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x, n):
                return x + self.attr.add_tensor(x)

        # ``n`` is passed as None to check None inputs alongside torchbind.
        self._test_export_same_as_eager(
            MyModule(), (torch.ones(2, 3), None), strict=False
        )

    def test_attribute(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + self.attr.add_tensor(x)

        self._test_export_same_as_eager(MyModule(), (torch.ones(2, 3),), strict=False)

    def test_attribute_as_custom_op_argument(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x)

        self._test_export_same_as_eager(MyModule(), (torch.ones(2, 3),), strict=False)

    def test_input(self):
        class MyModule(torch.nn.Module):
            def forward(self, x, cc):
                return x + cc.add_tensor(x)

        cc = torch.classes._TorchScriptTesting._Foo(10, 20)
        self._test_export_same_as_eager(
            MyModule(), (torch.ones(2, 3), cc), strict=False
        )

    def test_input_as_custom_op_argument(self):
        class MyModule(torch.nn.Module):
            def forward(self, x, cc):
                return x + torch.ops._TorchScriptTesting.takes_foo(cc, x)

        cc = torch.classes._TorchScriptTesting._Foo(10, 20)
        self._test_export_same_as_eager(
            MyModule(), (torch.ones(2, 3), cc), strict=False
        )

    def test_unlift_custom_obj(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a)
                return x + b

        self._test_export_unlift_reexport(MyModule(), (torch.ones(2, 3),))

    def test_custom_obj_list_out(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
                y = a[0] + a[1] + a[2]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        self._test_export_unlift_reexport(MyModule(), (torch.ones(2, 3),))

    def test_custom_obj_tuple_out(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        self._test_export_unlift_reexport(MyModule(), (torch.ones(2, 3),))
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed directly, hand off to run_tests from
    # torch.testing._internal.common_utils.
    run_tests()
|