mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This PR fakifies ScriptObject inputs and attributes in export non-strict mode by default. The basic idea is to `only fakify the script object during tracing (i.e. aot_export)`. After we get the traced graph module, eagerly executing, serializing, or running more passes will use the real script objects. This essentially treats the script object as a constant tensor. Concretely, we 1. fakify all the script object inputs and module attributes (gathered by constant_attrs), 2. patch the module's attributes with the fakified script objects, and 3. right after aot_export, remove the patching (to avoid changing the original module) and then set the exported graph module's attributes to the real script objects. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124239 Approved by: https://github.com/zou3519
882 lines
36 KiB
Python
882 lines
36 KiB
Python
# Owner(s): ["oncall: export"]
|
|
|
|
|
|
import torch
|
|
import torch.utils._pytree as pytree
|
|
from torch._functorch.aot_autograd import aot_export_module
|
|
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
|
|
from torch._library.fake_class_registry import FakeScriptObject
|
|
from torch.export import export
|
|
from torch.export._trace import _export
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
run_tests,
|
|
skipIfTorchDynamo,
|
|
TestCase,
|
|
)
|
|
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
|
|
|
|
|
|
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
|
|
class TestExportTorchbind(TestCase):
|
|
def setUp(self):
    """Load the torchbind test implementations and register fake classes.

    Two fake classes are registered, one for ``_TorchScriptTesting::_Foo`` and
    one for ``_TorchScriptTesting::_TensorQueue``. Their methods bump counters
    stored on this test case so tests can check how often they were called.
    """
    init_torchbind_implementations()

    # The nested fake classes close over the test case to update its counters.
    case = self
    for counter_name in (
        "tq_push_counter",
        "tq_pop_counter",
        "tq_size_counter",
        "foo_add_tensor_counter",
    ):
        setattr(case, counter_name, 0)

    @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
    class FakeFoo:
        def __init__(self, x: int, y: int):
            self.x = x
            self.y = y

        @classmethod
        def from_real(cls, foo):
            # __getstate__ is unpacked as ((x, y), <ignored second element>).
            state, _ = foo.__getstate__()
            x, y = state
            return cls(x, y)

        def add_tensor(self, z):
            case.foo_add_tensor_counter += 1
            total = self.x + self.y
            return total * z

    @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
    class FakeTensorQueue:
        def __init__(self, q):
            self.queue = q

        @classmethod
        def from_real(cls, real_tq):
            ctx = torch.library.get_ctx()
            fake_tensors = []
            for real_tensor in real_tq.get_raw_queue():
                fake_tensors.append(ctx.to_fake_tensor(real_tensor))
            return cls(fake_tensors)

        def push(self, x):
            case.tq_push_counter += 1
            self.queue.append(x)

        def pop(self):
            case.tq_pop_counter += 1
            return self.queue.pop(0)

        def size(self):
            case.tq_size_counter += 1
            return len(self.queue)

    # Operator packets that take a torchbind object; used by the tests below.
    testing_ns = torch.ops._TorchScriptTesting
    self.torch_bind_ops = [
        getattr(testing_ns, op_name)
        for op_name in (
            "takes_foo",
            "takes_foo_python_meta",
            "takes_foo_list_return",
            "takes_foo_tuple_return",
            "take_an_instance",
            "take_an_instance_inferred",
            "takes_foo_cia",
            "queue_pop",
            "queue_push",
            "queue_size",
        )
    ]
|
|
|
|
def tearDown(self):
    """Deregister the two fake classes that ``setUp`` registered."""
    deregister = torch._library.fake_class_registry.deregister_fake_class
    for qualname in (
        "_TorchScriptTesting::_Foo",
        "_TorchScriptTesting::_TensorQueue",
    ):
        deregister(qualname)
|
|
|
|
def _assertEqualSkipScriptObject(self, exp, actual):
|
|
flat_exp = pytree.tree_leaves(exp)
|
|
flat_actual = pytree.tree_leaves(actual)
|
|
self.assertEqual(len(flat_exp), len(flat_actual))
|
|
for a, b in zip(flat_exp, flat_actual):
|
|
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
|
|
continue
|
|
self.assertEqual(a, b)
|
|
|
|
def _test_export_same_as_eager(
|
|
self, f, args, kwargs=None, strict=True, pre_dispatch=False
|
|
):
|
|
kwargs = kwargs or {}
|
|
|
|
def export_wrapper(f, args, kwargs, strcit, pre_dispatch):
|
|
with enable_torchbind_tracing():
|
|
if pre_dispatch:
|
|
exported_program = _export(
|
|
f, args, kwargs, strict=strict, pre_dispatch=True
|
|
)
|
|
else:
|
|
exported_program = export(f, args, kwargs, strict=strict)
|
|
return exported_program
|
|
|
|
exported_program = export_wrapper(f, args, kwargs, strict, pre_dispatch)
|
|
reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)}
|
|
unlifted = exported_program.module()
|
|
exp = f(*args, **kwargs)
|
|
self.assertEqual(unlifted(*args, **kwargs), exp)
|
|
self.assertEqual(
|
|
unlifted(*args, **reversed_kwargs),
|
|
exp,
|
|
)
|
|
|
|
# check re-tracing
|
|
retraced_ep = export_wrapper(unlifted, args, kwargs, strict, pre_dispatch)
|
|
self.assertEqual(retraced_ep.module()(*args, **kwargs), exp)
|
|
return exported_program
|
|
|
|
@parametrize("pre_dispatch", [True, False])
def test_none(self, pre_dispatch):
    """Non-strict export of a module that gets ``None`` as its second input.

    The module calls ``add_tensor`` on a ``_Foo`` script-object attribute.
    Both the unlifted module code and the exported graph module code are
    compared against the expected generated source.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

        def forward(self, x, n):
            # ``n`` is not used by the computation; the test passes None for it.
            return x + self.attr.add_tensor(x)

    ep = self._test_export_same_as_eager(
        MyModule(),
        (torch.ones(2, 3), None),
        strict=False,
        pre_dispatch=pre_dispatch,
    )
    # Unlifted module: the script object is read back from ``self.attr``.
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x, n):
    x, n, = fx_pytree.tree_flatten_spec(([x, n], {}), self._in_spec)
    attr = self.attr
    call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x);  attr = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind);  x = call_torchbind = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
    )
    # Exported graph: the script object appears as the ``obj_attr`` input.
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, obj_attr, x, n):
    call_torchbind = torch.ops.higher_order.call_torchbind(obj_attr, 'add_tensor', x);  obj_attr = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind);  x = call_torchbind = None
    return (add,)""",
    )
|
|
|
|
@parametrize("pre_dispatch", [True, False])
def test_attribute(self, pre_dispatch):
    """Non-strict export of a module calling a method on a script-object attribute.

    The module calls ``add_tensor`` on its ``_Foo`` attribute. The generated
    code of the unlifted module and of the exported graph module is compared
    against the expected source.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

        def forward(self, x):
            return x + self.attr.add_tensor(x)

    ep = self._test_export_same_as_eager(
        MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
    )
    # Unlifted module: the script object is read back from ``self.attr``.
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x);  attr = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind);  x = call_torchbind = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
    )
    # Exported graph: the script object appears as the ``obj_attr`` input.
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, obj_attr, x):
    call_torchbind = torch.ops.higher_order.call_torchbind(obj_attr, 'add_tensor', x);  obj_attr = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind);  x = call_torchbind = None
    return (add,)""",
    )
|
|
|
|
@parametrize("pre_dispatch", [True, False])
def test_attribute_as_custom_op_argument(self, pre_dispatch):
    """Non-strict export where a script-object attribute is passed to a custom op.

    The module passes its ``_Foo`` attribute to the ``takes_foo`` op. In the
    expected exported graph the op call is wrapped in ``with_effects`` and a
    ``token`` is taken as the first input and returned as the first output.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

        def forward(self, x):
            return x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x)

    ep = self._test_export_same_as_eager(
        MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
    )
    # Unlifted module: the custom op is called directly, with no token.
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x);  attr = None
    add = torch.ops.aten.add.Tensor(x, takes_foo_default);  x = takes_foo_default = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
    )
    # Exported graph: the op call goes through ``with_effects``.
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, token, obj_attr, x):
    with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, x);  token = obj_attr = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    add = torch.ops.aten.add.Tensor(x, getitem_1);  x = getitem_1 = None
    return (getitem, add)""",  # noqa: B950
    )
|
|
|
|
@parametrize("pre_dispatch", [True, False])
def test_input(self, pre_dispatch):
    """Non-strict export where the script object is a positional input.

    The module calls ``add_tensor`` on the ``_Foo`` object it receives as
    ``cc``. Besides the generated code, the test checks how many times the
    fake ``add_tensor`` registered in ``setUp`` was invoked.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, cc):
            return x + cc.add_tensor(x)

    cc = torch.classes._TorchScriptTesting._Foo(10, 20)
    ep = self._test_export_same_as_eager(
        MyModule(), (torch.ones(2, 3), cc), strict=False, pre_dispatch=pre_dispatch
    )
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x, cc):
    x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
    call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x);  cc = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind);  x = call_torchbind = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
    )
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, x, cc):
    call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x);  cc = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind);  x = call_torchbind = None
    return (add,)""",
    )
    # aot_export_function runs the program twice
    # in run_functionalized_fw_and_collect_metadata and create_aot_dispatcher_function
    # We also have a re-tracing test, which doubles the count.
    self.assertEqual(self.foo_add_tensor_counter, 4)
|
|
|
|
@parametrize("pre_dispatch", [True, False])
def test_input_as_custom_op_argument(self, pre_dispatch):
    """Non-strict export where a script-object input is passed to a custom op.

    First the Python Meta kernel of ``takes_foo.default`` is removed and the
    export is expected to fail with "no python implementation is found". A
    Python Meta kernel is then registered again and the export is expected
    to succeed and produce the expected generated code.

    The kernel table is process-global state, so re-registration happens in a
    ``finally`` clause: a failure of the first expectation must not leave the
    operator without a Python Meta kernel for later tests.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, cc):
            return x + torch.ops._TorchScriptTesting.takes_foo(cc, x)

    cc = torch.classes._TorchScriptTesting._Foo(10, 20)

    takes_foo = torch.ops._TorchScriptTesting.takes_foo.default
    meta_key = torch._C.DispatchKey.Meta

    del takes_foo.py_kernels[meta_key]
    takes_foo._dispatch_cache.clear()
    try:
        # Even though a C++ implementation for takes_foo.default is registered,
        # we still need the python implementation for takes_foo.default to trace with FakeFoo.
        with self.assertRaisesRegex(RuntimeError, "no python implementation is found"):
            self._test_export_same_as_eager(
                MyModule(),
                (torch.ones(2, 3), cc),
                strict=False,
                pre_dispatch=pre_dispatch,
            )
    finally:
        # Put a Python Meta kernel back whether or not the check above passed.
        takes_foo.py_impl(meta_key)(lambda cc, x: cc.add_tensor(x))

    ep = self._test_export_same_as_eager(
        MyModule(),
        (torch.ones(2, 3), cc),
        strict=False,
        pre_dispatch=pre_dispatch,
    )

    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x, cc):
    x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(cc, x);  cc = None
    add = torch.ops.aten.add.Tensor(x, takes_foo_default);  x = takes_foo_default = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
    )
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, token, x, cc):
    with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, cc, x);  token = cc = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    add = torch.ops.aten.add.Tensor(x, getitem_1);  x = getitem_1 = None
    return (getitem, add)""",  # noqa: B950
    )
|
|
|
|
@parametrize("pre_dispatch", [True, False])
def test_unlift_custom_obj(self, pre_dispatch):
    """Non-strict export of two chained ``takes_foo`` calls on one attribute.

    The output of the first ``takes_foo`` call is fed to the second one. In
    the expected exported graph the token returned by the first
    ``with_effects`` call is the token passed to the second.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

        def forward(self, x):
            a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
            b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a)
            return x + b

    input = torch.ones(2, 3)
    ep = self._test_export_same_as_eager(
        MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
    )
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1);  attr = takes_foo_default_1 = None
    add = torch.ops.aten.add.Tensor(x, takes_foo_default);  x = takes_foo_default = None
    return pytree.tree_unflatten((add,), self._out_spec)""",  # noqa: B950
    )
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, token, obj_attr, x):
    with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, x);  token = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, getitem_1);  getitem = obj_attr = getitem_1 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1];  with_effects_1 = None
    add = torch.ops.aten.add.Tensor(x, getitem_3);  x = getitem_3 = None
    return (getitem_2, add)""",  # noqa: B950
    )
|
|
|
|
@parametrize("pre_dispatch", [True, False])
def test_custom_obj_list_out(self, pre_dispatch):
    """Non-strict export of a custom op whose result is indexed like a list.

    ``takes_foo_list_return`` is called on the ``_Foo`` attribute, its first
    three elements are summed, and the sum is passed to ``takes_foo``. In the
    expected exported graph the list result is the single element at index 1
    of the ``with_effects`` output and is indexed afterwards.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

        def forward(self, x):
            a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
            y = a[0] + a[1] + a[2]
            b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
            return x + b

    input = torch.ones(2, 3)
    ep = self._test_export_same_as_eager(
        MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
    )
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    takes_foo_list_return_default = torch.ops._TorchScriptTesting.takes_foo_list_return.default(attr, x)
    getitem_2 = takes_foo_list_return_default[0]
    getitem_3 = takes_foo_list_return_default[1]
    getitem_4 = takes_foo_list_return_default[2];  takes_foo_list_return_default = None
    add = torch.ops.aten.add.Tensor(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
    add_1 = torch.ops.aten.add.Tensor(add, getitem_4);  add = getitem_4 = None
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, add_1);  attr = add_1 = None
    add_2 = torch.ops.aten.add.Tensor(x, takes_foo_default);  x = takes_foo_default = None
    return pytree.tree_unflatten((add_2,), self._out_spec)""",
    )
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, token, obj_attr, x):
    with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_list_return.default, obj_attr, x);  token = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    getitem_2 = getitem_1[0]
    getitem_3 = getitem_1[1]
    getitem_4 = getitem_1[2];  getitem_1 = None
    add = torch.ops.aten.add.Tensor(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
    add_1 = torch.ops.aten.add.Tensor(add, getitem_4);  add = getitem_4 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add_1);  getitem = obj_attr = add_1 = None
    getitem_5 = with_effects_1[0]
    getitem_6 = with_effects_1[1];  with_effects_1 = None
    add_2 = torch.ops.aten.add.Tensor(x, getitem_6);  x = getitem_6 = None
    return (getitem_5, add_2)""",  # noqa: B950
    )
|
|
|
|
@parametrize("pre_dispatch", [True, False])
def test_custom_obj_tuple_out(self, pre_dispatch):
    """Non-strict export of a custom op whose result is indexed like a tuple.

    ``takes_foo_tuple_return`` is called on the ``_Foo`` attribute, its first
    two elements are summed, and the sum is passed to ``takes_foo``. In the
    expected exported graph the tuple elements are read directly from indices
    1 and 2 of the ``with_effects`` output.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

        def forward(self, x):
            a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
            y = a[0] + a[1]
            b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
            return x + b

    input = torch.ones(2, 3)
    ep = self._test_export_same_as_eager(
        MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
    )
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(attr, x)
    getitem_1 = takes_foo_tuple_return_default[0]
    getitem_2 = takes_foo_tuple_return_default[1];  takes_foo_tuple_return_default = None
    add = torch.ops.aten.add.Tensor(getitem_1, getitem_2);  getitem_1 = getitem_2 = None
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, add);  attr = add = None
    add_1 = torch.ops.aten.add.Tensor(x, takes_foo_default);  x = takes_foo_default = None
    return pytree.tree_unflatten((add_1,), self._out_spec)""",
    )
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, token, obj_attr, x):
    with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, obj_attr, x);  token = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]
    getitem_2 = with_effects[2];  with_effects = None
    add = torch.ops.aten.add.Tensor(getitem_1, getitem_2);  getitem_1 = getitem_2 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add);  getitem = obj_attr = add = None
    getitem_3 = with_effects_1[0]
    getitem_4 = with_effects_1[1];  with_effects_1 = None
    add_1 = torch.ops.aten.add.Tensor(x, getitem_4);  x = getitem_4 = None
    return (getitem_3, add_1)""",  # noqa: B950
    )
|
|
|
|
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
def test_make_fx_tensor_queue_methods(self, make_fx_tracing_mode):
    """Trace ``push``/``pop``/``size`` method calls on a tensor queue with ``make_fx``.

    Checks that the queue seen during tracing is a ``FakeScriptObject``, that
    the fake methods registered in ``setUp`` ran the expected number of times,
    that the real queue is still empty after tracing, and that the traced
    graph agrees with eager execution (script-object leaves are skipped in
    that comparison).
    """
    test = self

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 2)
            # Flipped to False before the eager run below, where tq is real.
            self.check_tq_is_fake = True

        def forward(self, tq, x):
            if self.check_tq_is_fake:
                test.assertTrue(isinstance(tq, FakeScriptObject))
            tq.push(x.cos())
            tq.push(x.sin())
            x_cos = tq.pop() + tq.size()
            x_sin = tq.pop() - tq.size()
            return x_sin, x_cos, tq

    mod = Model()
    tq = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    # Second queue for the eager reference run.
    tq1 = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    x = torch.ones(2, 3)
    gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
    self.assertEqual(self.tq_push_counter, 2)
    self.assertEqual(self.tq_pop_counter, 2)
    self.assertEqual(self.tq_size_counter, 2)
    # Tracing must not have pushed anything onto the real queue.
    self.assertEqual(tq.size(), 0)
    self.assertExpectedInline(
        gm.code.strip("\n"),
        """\
def forward(self, arg0_1, arg1_1):
    cos = torch.ops.aten.cos.default(arg1_1)
    call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'push', cos);  cos = None
    sin = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
    call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'push', sin);  sin = None
    call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
    call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
    add = torch.ops.aten.add.Tensor(call_torchbind_2, 1);  call_torchbind_2 = None
    call_torchbind_4 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
    call_torchbind_5 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
    sub = torch.ops.aten.sub.Tensor(call_torchbind_4, 0);  call_torchbind_4 = None
    return (sub, add, arg0_1)
    """,
    )
    mod.check_tq_is_fake = False
    self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
|
|
|
|
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
def test_make_fx_tensor_queue_methods_fakify_internal_states(
    self, make_fx_tracing_mode
):
    """Trace ``pop``/``size`` on a queue that already holds tensors.

    Two tensors are pushed onto the real queues before tracing. The test
    checks that tracing calls only the fake ``pop``/``size`` methods, leaves
    the real queue's size unchanged, and that running the traced graph and
    the eager module each drain their queue.
    """
    test = self

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 2)
            self.check_tq_is_fake = True
            self.current_test = test

        def forward(self, tq, x):
            if self.check_tq_is_fake:
                self.current_test.assertTrue(isinstance(tq, FakeScriptObject))
            x_cos = tq.pop() + tq.size() + x
            x_sin = tq.pop() - tq.size() + x
            return x_sin, x_cos, tq

    mod = Model()
    tq = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    # Second queue, filled identically, for the eager reference run.
    tq1 = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    for _ in range(2):
        tq.push(torch.ones(2, 3))
        tq1.push(torch.ones(2, 3))
    x = torch.ones(2, 3)
    prev_size = tq.size()
    gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
    self.assertEqual(self.tq_push_counter, 0)
    self.assertEqual(self.tq_pop_counter, 2)
    self.assertEqual(self.tq_size_counter, 2)
    # The pops during tracing must not have touched the real queue.
    self.assertEqual(tq.size(), prev_size)
    self.assertExpectedInline(
        gm.code.strip("\n"),
        """\
def forward(self, arg0_1, arg1_1):
    call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
    call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
    add = torch.ops.aten.add.Tensor(call_torchbind, 1);  call_torchbind = None
    add_1 = torch.ops.aten.add.Tensor(add, arg1_1);  add = None
    call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
    call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
    sub = torch.ops.aten.sub.Tensor(call_torchbind_2, 0);  call_torchbind_2 = None
    add_2 = torch.ops.aten.add.Tensor(sub, arg1_1);  sub = arg1_1 = None
    return (add_2, add_1, arg0_1)
    """,
    )
    # turn off tq type checking in eager execution
    mod.check_tq_is_fake = False
    self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
    self.assertEqual(tq.size(), 0)
    self.assertEqual(tq1.size(), 0)
|
|
|
|
def test_identifying_torchbind_ops(self):
    """Check ``_has_torchbind_op_overload`` on torchbind and on aten op packets.

    Every packet collected in ``self.torch_bind_ops`` must report the flag as
    true; ``aten.add`` and ``aten.cos`` must report it as false.
    """
    plain_aten_packets = (torch.ops.aten.add, torch.ops.aten.cos)

    for torchbind_packet in self.torch_bind_ops:
        self.assertTrue(torchbind_packet._has_torchbind_op_overload)

    for aten_packet in plain_aten_packets:
        self.assertFalse(aten_packet._has_torchbind_op_overload)
|
|
|
|
def test_torchbind_op_register_fallthrough(self):
    """Register a fallthrough kernel for AutocastCPU on every torchbind op.

    For each op, the kernel is registered through a scoped library fragment
    and the dispatcher is then asked whether the kernel for that key is a
    fallthrough.
    """
    dispatch_key = torch._C.DispatchKey.AutocastCPU
    dispatch_key_name = "AutocastCPU"

    for packet in self.torch_bind_ops:
        overload = packet.default
        namespace, _ = torch._library.utils.parse_namespace(packet._qualified_op_name)
        with torch.library._scoped_library(namespace, "FRAGMENT") as fragment:
            fragment.impl(
                overload.name(), torch.library.fallthrough_kernel, dispatch_key_name
            )
            is_fallthrough = torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
                overload.name(), dispatch_key
            )
            self.assertTrue(is_fallthrough)
|
|
|
|
def test_torchbind_op_fallthrough_keys_respects_lib_impl(self):
    """Check ``_fallthrough_keys()`` against kernels registered via ``lib.impl``.

    Only ops that have neither a dispatcher kernel nor a Python kernel for
    AutogradCPU are exercised. For those, registering an ordinary kernel must
    keep the key out of ``_fallthrough_keys()``, while registering
    ``fallthrough_kernel`` must put it in. At least one op must be exercised.
    """
    dispatch_key = torch._C.DispatchKey.AutogradCPU
    dispatch_key_name = "AutogradCPU"

    exercised = 0
    for packet in self.torch_bind_ops:
        overload = packet.default
        namespace, _ = torch._library.utils.parse_namespace(packet._qualified_op_name)

        # Skip ops that already have a kernel of either kind for this key.
        if (
            torch._C._dispatch_has_kernel_for_dispatch_key(
                overload.name(), dispatch_key
            )
            or dispatch_key in overload.py_kernels
        ):
            continue

        exercised += 1
        with torch.library._scoped_library(namespace, "FRAGMENT") as fragment:
            fragment.impl(
                overload.name(), lambda *args, **kwargs: args, dispatch_key_name
            )
            self.assertTrue(dispatch_key not in overload._fallthrough_keys())

        with torch.library._scoped_library(namespace, "FRAGMENT") as fragment:
            fragment.impl(
                overload.name(),
                torch.library.fallthrough_kernel,
                dispatch_key_name,
            )
            self.assertTrue(dispatch_key in overload._fallthrough_keys())

    self.assertTrue(exercised > 0)
|
|
|
|
def test_make_fx_schema_checking_script_object(self):
    """``make_fx`` must reject a non-fake script object passed to ``queue_push``.

    Both models hand ``foo`` (a ``_Foo`` instance) to ``queue_push``, once
    positionally and once by keyword. Tracing either in fake mode is expected
    to raise a ``RuntimeError`` mentioning "is expected to be a
    FakeScriptObject".
    """

    class Model(torch.nn.Module):
        def forward(self, tq, x, foo):
            torch.ops._TorchScriptTesting.queue_push(foo, x.cos())
            return tq

    class ModelCallByKW(torch.nn.Module):
        def forward(self, tq, x, foo):
            torch.ops._TorchScriptTesting.queue_push(x=x.cos(), foo=foo)
            return tq

    positional_model = Model()
    keyword_model = ModelCallByKW()

    foo = torch.classes._TorchScriptTesting._Foo(10, 20)
    x = torch.ones(3, 3)
    empty_seed = torch.empty(
        0,
    ).fill_(-1)
    tq = torch.classes._TorchScriptTesting._TensorQueue(empty_seed)

    with torch.library._scoped_library("_TorchScriptTesting", "FRAGMENT") as fragment:
        push_packet = torch.ops._TorchScriptTesting.queue_push
        # Make these dispatch keys fall through for queue_push while tracing.
        for key_name in ("AutogradCPU", "ADInplaceOrView", "PythonTLSSnapshot"):
            fragment.impl(
                push_packet.__name__, torch.library.fallthrough_kernel, key_name
            )

        for model in (positional_model, keyword_model):
            with self.assertRaisesRegex(
                RuntimeError, "is expected to be a FakeScriptObject"
            ):
                make_fx(model, tracing_mode="fake")(tq, x, foo)
|
|
|
|
@parametrize("fallthrough_via", ["lib_impl", "py_impl"])
def test_make_fx_tensor_queue_operators(self, fallthrough_via):
    """Trace the queue custom ops under autocast with ``make_fx``.

    A fallthrough kernel for AutocastCUDA is installed on ``queue_push``,
    ``queue_pop`` and ``queue_size`` either through a scoped library fragment
    (``"lib_impl"``) or through ``py_impl`` (``"py_impl"``). The traced code
    is compared against the expected source and the traced graph is checked
    against eager execution.

    In the ``py_impl`` branch the kernels live in process-global tables, so
    they are removed in a ``finally`` clause; a tracing failure must not
    leave them registered for later tests.
    """

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, tq, x):
            with torch.autocast("cuda", dtype=torch.bfloat16):
                torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
                torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
                x_sin = torch.ops._TorchScriptTesting.queue_pop(
                    tq
                ) - torch.ops._TorchScriptTesting.queue_size(tq)
                x_cos = torch.ops._TorchScriptTesting.queue_pop(
                    tq
                ) + torch.ops._TorchScriptTesting.queue_size(tq)
                return x_sin, x_cos, tq

    mod = Model()

    tq1 = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    tq2 = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    x = torch.ones(2, 3)

    # Eager run before tracing; it pushes twice and pops twice on tq1.
    mod(tq1, x)

    ops = [
        torch.ops._TorchScriptTesting.queue_push,
        torch.ops._TorchScriptTesting.queue_pop,
        torch.ops._TorchScriptTesting.queue_size,
    ]
    if fallthrough_via == "lib_impl":
        ns = "_TorchScriptTesting"
        with torch.library._scoped_library(ns, "FRAGMENT") as lib:
            for op in ops:
                lib.impl(
                    op.__name__, torch.library.fallthrough_kernel, "AutocastCUDA"
                )

            gm = make_fx(mod, tracing_mode="fake")(tq1, x)
    else:
        autocast_key = torch._C.DispatchKey.AutocastCUDA
        # Only ops whose registration succeeded are cleaned up afterwards.
        registered = []
        try:
            for op in ops:
                op.default.py_impl(autocast_key)(torch.library.fallthrough_kernel)
                registered.append(op)
            gm = make_fx(mod, tracing_mode="fake")(tq1, x)
        finally:
            for op in registered:
                del op.default.py_kernels[autocast_key]
                # Drop cached dispatch results that may refer to the removed kernel.
                op.default._dispatch_cache.clear()

    self.assertExpectedInline(
        gm.code.strip(),
        """\
def forward(self, arg0_1, arg1_1):
    cos = torch.ops.aten.cos.default(arg1_1)
    queue_push = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, cos);  cos = None
    sin = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
    queue_push_1 = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, sin);  sin = None
    queue_pop = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1)
    queue_size = torch.ops._TorchScriptTesting.queue_size.default(arg0_1)
    sub = torch.ops.aten.sub.Tensor(queue_pop, 1);  queue_pop = None
    queue_pop_1 = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1)
    queue_size_1 = torch.ops._TorchScriptTesting.queue_size.default(arg0_1)
    add = torch.ops.aten.add.Tensor(queue_pop_1, 0);  queue_pop_1 = None
    return (sub, add, arg0_1)""",
    )
    self._assertEqualSkipScriptObject(gm(tq1, x), mod(tq2, x))
|
|
|
|
def test_aot_export_tensor_queue_operators(self):
    """Run ``aot_export_module`` on a model that calls the queue custom ops.

    The queue and the tensor input are converted to fake objects by hand
    under one ``FakeTensorMode`` before export. In the expected graph every
    custom-op call is wrapped in ``with_effects``, with the token returned by
    one call passed to the next.
    """

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, tq, x):
            torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
            torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
            x_sin = torch.ops._TorchScriptTesting.queue_pop(
                tq
            ) - torch.ops._TorchScriptTesting.queue_size(tq)
            x_cos = torch.ops._TorchScriptTesting.queue_pop(
                tq
            ) + torch.ops._TorchScriptTesting.queue_size(tq)
            return x_sin, x_cos, tq

    mod = Model()

    tq1 = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    x = torch.ones(2, 3)

    # Fakify both inputs under the same fake mode before exporting.
    fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
    fake_tq1 = torch._library.fake_class_registry.to_fake_obj(fake_mode, tq1)
    fake_x = fake_mode.from_tensor(x)
    gm = aot_export_module(mod, (fake_tq1, fake_x), trace_joint=False)[0]

    # inputs: token, tq, x
    # return: token, x_sin, x_cos, tq
    self.assertExpectedInline(
        gm.code.strip(),
        """\
def forward(self, arg0_1, arg1_1, arg2_1):
    cos = torch.ops.aten.cos.default(arg2_1)
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, cos);  arg0_1 = cos = None
    getitem = with_effects[0];  with_effects = None
    sin = torch.ops.aten.sin.default(arg2_1);  arg2_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, sin);  getitem = sin = None
    getitem_2 = with_effects_1[0];  with_effects_1 = None
    with_effects_2 = torch._higher_order_ops.effects.with_effects(getitem_2, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1);  getitem_2 = None
    getitem_4 = with_effects_2[0]
    getitem_5 = with_effects_2[1];  with_effects_2 = None
    with_effects_3 = torch._higher_order_ops.effects.with_effects(getitem_4, torch.ops._TorchScriptTesting.queue_size.default, arg1_1);  getitem_4 = None
    getitem_6 = with_effects_3[0];  with_effects_3 = None
    sub = torch.ops.aten.sub.Tensor(getitem_5, 1);  getitem_5 = None
    with_effects_4 = torch._higher_order_ops.effects.with_effects(getitem_6, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1);  getitem_6 = None
    getitem_8 = with_effects_4[0]
    getitem_9 = with_effects_4[1];  with_effects_4 = None
    with_effects_5 = torch._higher_order_ops.effects.with_effects(getitem_8, torch.ops._TorchScriptTesting.queue_size.default, arg1_1);  getitem_8 = None
    getitem_10 = with_effects_5[0];  with_effects_5 = None
    add = torch.ops.aten.add.Tensor(getitem_9, 0);  getitem_9 = None
    return (getitem_10, sub, add, arg1_1)""",  # noqa: B950
    )
|
|
|
|
|
|
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestRegisterFakeClass(TestCase):
    """Tests for the checks performed by ``torch._library.register_fake_class``."""

    def setUp(self):
        """Load the torchbind test implementations before each test."""
        init_torchbind_implementations()

    def tearDown(self):
        """Empty the global fake class registry after each test."""
        registry = torch._library.fake_class_registry.global_fake_class_registry
        registry.clear()

    def test_register_fake_class_no_torch_bind_class(self):
        """Registering a fake for an unknown torchbind class name raises."""
        unknown_qualname = "_TorchScriptTesting::NOT_A_VALID_NAME"
        with self.assertRaisesRegex(RuntimeError, "Tried to instantiate class"):

            @torch._library.register_fake_class(unknown_qualname)
            class Invalid:
                pass

    def test_register_fake_class_no_from_real(self):
        """A fake class without ``from_real`` is rejected."""
        foo_qualname = "_TorchScriptTesting::_Foo"
        with self.assertRaisesRegex(RuntimeError, "define a classmethod from_real"):

            @torch._library.register_fake_class(foo_qualname)
            class InvalidFakeFoo:
                def __init__(self):
                    pass

    def test_register_fake_class_from_real_not_classmethod(self):
        """A fake class whose ``from_real`` is a plain method is rejected."""
        foo_qualname = "_TorchScriptTesting::_Foo"
        with self.assertRaisesRegex(RuntimeError, "is not a classmethod"):

            @torch._library.register_fake_class(foo_qualname)
            class FakeFoo:
                def __init__(self, x, y):
                    self.x = x
                    self.y = y

                # Deliberately not decorated with @classmethod.
                def from_real(self, foo_obj):
                    state = foo_obj.__getstate__()
                    x, y = state
                    return FakeFoo(x, y)

    def test_register_fake_class_valid(self):
        """A well-formed fake class registers through the call form without error."""

        class FakeFoo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            @classmethod
            def from_real(cls, foo_obj):
                state = foo_obj.__getstate__()
                x, y = state
                return cls(x, y)

        foo_qualname = "_TorchScriptTesting::_Foo"
        torch._library.register_fake_class(foo_qualname, FakeFoo)

    def test_register_fake_class_duplicate_registration(self):
        """Registering a second fake for the same class emits a ``UserWarning``."""
        foo_qualname = "_TorchScriptTesting::_Foo"

        # First registration, done through the decorator form.
        @torch._library.register_fake_class(foo_qualname)
        class FakeFoo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            @classmethod
            def from_real(cls, foo_obj):
                state = foo_obj.__getstate__()
                x, y = state
                return cls(x, y)

        # Second registration of the same qualified name must warn.
        with self.assertWarnsRegex(UserWarning, "already registered"):
            torch._library.register_fake_class(foo_qualname, FakeFoo)
|
|
|
|
|
|
# Instantiate the @parametrize-decorated tests of TestExportTorchbind.
instantiate_parametrized_tests(TestExportTorchbind)

# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
|