[dynamo] support torchbind object input (#124978)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124978
Approved by: https://github.com/jansel
This commit is contained in:
ydwu4 2024-05-06 15:03:02 -07:00 committed by PyTorch MergeBot
parent c165a8e71d
commit 461ffaaaf3
15 changed files with 766 additions and 110 deletions

View File

@ -59,6 +59,10 @@ struct Foo : torch::CustomClassHolder {
bool eq(c10::intrusive_ptr<Foo> other) { bool eq(c10::intrusive_ptr<Foo> other) {
return this->x == other->x && this->y == other->y; return this->x == other->x && this->y == other->y;
} }
std::tuple<std::tuple<std::string, int64_t>, std::tuple<std::string, int64_t>>
__obj_flatten__() {
return std::tuple(std::tuple("x", this->x), std::tuple("y", this->y));
}
}; };
struct _StaticMethod : torch::CustomClassHolder { struct _StaticMethod : torch::CustomClassHolder {
@ -199,6 +203,10 @@ struct TensorQueue : torch::CustomClassHolder {
return raw_queue; return raw_queue;
} }
std::tuple<std::tuple<std::string, std::vector<at::Tensor>>> __obj_flatten__() {
return std::tuple(std::tuple("queue", this->get_raw_queue()));
}
private: private:
std::deque<at::Tensor> queue_; std::deque<at::Tensor> queue_;
std::mutex mutex_; std::mutex mutex_;
@ -370,6 +378,10 @@ struct ContainsTensor : public torch::CustomClassHolder {
return t_; return t_;
} }
std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
return std::tuple(std::tuple("t", this->t_));
}
at::Tensor t_; at::Tensor t_;
}; };
@ -417,6 +429,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
.def("add_tensor", &Foo::add_tensor) .def("add_tensor", &Foo::add_tensor)
.def("__eq__", &Foo::eq) .def("__eq__", &Foo::eq)
.def("combine", &Foo::combine) .def("combine", &Foo::combine)
.def("__obj_flatten__", &Foo::__obj_flatten__)
.def_pickle( .def_pickle(
[](c10::intrusive_ptr<Foo> self) { // __getstate__ [](c10::intrusive_ptr<Foo> self) { // __getstate__
return std::vector<int64_t>{self->x, self->y}; return std::vector<int64_t>{self->x, self->y};
@ -424,6 +437,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
[](std::vector<int64_t> state) { // __setstate__ [](std::vector<int64_t> state) { // __setstate__
return c10::make_intrusive<Foo>(state[0], state[1]); return c10::make_intrusive<Foo>(state[0], state[1]);
}); });
m.def( m.def(
"takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor"); "takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
m.def( m.def(
@ -551,6 +565,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
m.class_<ContainsTensor>("_ContainsTensor") m.class_<ContainsTensor>("_ContainsTensor")
.def(torch::init<at::Tensor>()) .def(torch::init<at::Tensor>())
.def("get", &ContainsTensor::get) .def("get", &ContainsTensor::get)
.def("__obj_flatten__", &ContainsTensor::__obj_flatten__)
.def_pickle( .def_pickle(
// __getstate__ // __getstate__
[](const c10::intrusive_ptr<ContainsTensor>& self) -> at::Tensor { [](const c10::intrusive_ptr<ContainsTensor>& self) -> at::Tensor {
@ -568,6 +583,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
.def("size", &TensorQueue::size) .def("size", &TensorQueue::size)
.def("clone_queue", &TensorQueue::clone_queue) .def("clone_queue", &TensorQueue::clone_queue)
.def("get_raw_queue", &TensorQueue::get_raw_queue) .def("get_raw_queue", &TensorQueue::get_raw_queue)
.def("__obj_flatten__", &TensorQueue::__obj_flatten__)
.def_pickle( .def_pickle(
// __getstate__ // __getstate__
[](const c10::intrusive_ptr<TensorQueue>& self) [](const c10::intrusive_ptr<TensorQueue>& self)

View File

@ -3,8 +3,11 @@
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
from torch._dynamo.testing import EagerAndRecordGraphs
from torch._functorch.aot_autograd import aot_export_module from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch._higher_order_ops.wrap import wrap
from torch._library.fake_class_registry import FakeScriptObject from torch._library.fake_class_registry import FakeScriptObject
from torch.export import export from torch.export import export
from torch.export._trace import _export from torch.export._trace import _export
@ -16,7 +19,39 @@ from torch.testing._internal.common_utils import (
skipIfTorchDynamo, skipIfTorchDynamo,
TestCase, TestCase,
) )
from torch.testing._internal.torchbind_impls import init_torchbind_implementations from torch.testing._internal.torchbind_impls import (
_empty_tensor_queue,
init_torchbind_implementations,
)
def _assertEqualSkipScriptObject(test_case, exp, actual):
flat_exp = pytree.tree_leaves(exp)
flat_actual = pytree.tree_leaves(actual)
test_case.assertEqual(len(flat_exp), len(flat_actual))
for a, b in zip(flat_exp, flat_actual):
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
continue
test_case.assertEqual(a, b)
def _check_script_obj_equal(test_case, a: torch.ScriptObject, b: torch.ScriptObject):
return test_case.assertEqual(
a._type().qualified_name(), b._type().qualified_name()
) and test_case.assertEqual(a.__obj_flatten__(), b.__obj_flatten__())
def _assertEqualScriptObject(
test_case, exp, actual, check_obj_eq=_check_script_obj_equal
):
flat_exp = pytree.tree_leaves(exp)
flat_actual = pytree.tree_leaves(actual)
test_case.assertEqual(len(flat_exp), len(flat_actual))
for a, b in zip(flat_exp, flat_actual):
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
check_obj_eq(test_case, a, b)
else:
test_case.assertEqual(a, b)
@skipIfTorchDynamo("torchbind not supported with dynamo yet") @skipIfTorchDynamo("torchbind not supported with dynamo yet")
@ -37,9 +72,8 @@ class TestExportTorchbind(TestCase):
self.y = y self.y = y
@classmethod @classmethod
def from_real(cls, foo): def __obj_unflatten__(cls, flattend_foo):
(x, y), _ = foo.__getstate__() return cls(**dict(flattend_foo))
return cls(x, y)
def add_tensor(self, z): def add_tensor(self, z):
test.foo_add_tensor_counter += 1 test.foo_add_tensor_counter += 1
@ -47,14 +81,12 @@ class TestExportTorchbind(TestCase):
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
class FakeTensorQueue: class FakeTensorQueue:
def __init__(self, q): def __init__(self, queue):
self.queue = q self.queue = queue
@classmethod @classmethod
def from_real(cls, real_tq): def __obj_unflatten__(cls, flattened_ctx):
ctx = torch.library.get_ctx() return cls(**dict(flattened_ctx))
fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
return cls(fake_queue)
def push(self, x): def push(self, x):
test.tq_push_counter += 1 test.tq_push_counter += 1
@ -89,15 +121,6 @@ class TestExportTorchbind(TestCase):
"_TorchScriptTesting::_TensorQueue" "_TorchScriptTesting::_TensorQueue"
) )
def _assertEqualSkipScriptObject(self, exp, actual):
flat_exp = pytree.tree_leaves(exp)
flat_actual = pytree.tree_leaves(actual)
self.assertEqual(len(flat_exp), len(flat_actual))
for a, b in zip(flat_exp, flat_actual):
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
continue
self.assertEqual(a, b)
def _test_export_same_as_eager( def _test_export_same_as_eager(
self, f, args, kwargs=None, strict=True, pre_dispatch=False self, f, args, kwargs=None, strict=True, pre_dispatch=False
): ):
@ -532,7 +555,7 @@ def forward(self, arg0_1, arg1_1):
""", """,
) )
mod.check_tq_is_fake = False mod.check_tq_is_fake = False
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x)) _assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x))
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"]) @parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
def test_make_fx_tensor_queue_methods_fakify_internal_states( def test_make_fx_tensor_queue_methods_fakify_internal_states(
@ -592,7 +615,7 @@ def forward(self, arg0_1, arg1_1):
) )
# turn off tq type checking in eager execution # turn off tq type checking in eager execution
mod.check_tq_is_fake = False mod.check_tq_is_fake = False
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x)) _assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x))
self.assertEqual(tq.size(), 0) self.assertEqual(tq.size(), 0)
self.assertEqual(tq1.size(), 0) self.assertEqual(tq1.size(), 0)
@ -769,7 +792,7 @@ def forward(self, arg0_1, arg1_1):
add = torch.ops.aten.add.Tensor(queue_pop_1, 0); queue_pop_1 = None add = torch.ops.aten.add.Tensor(queue_pop_1, 0); queue_pop_1 = None
return (sub, add, arg0_1)""", return (sub, add, arg0_1)""",
) )
self._assertEqualSkipScriptObject(gm(tq1, x), mod(tq2, x)) _assertEqualSkipScriptObject(self, gm(tq1, x), mod(tq2, x))
def test_aot_export_tensor_queue_operators(self): def test_aot_export_tensor_queue_operators(self):
class Model(torch.nn.Module): class Model(torch.nn.Module):
@ -829,6 +852,377 @@ def forward(self, arg0_1, arg1_1, arg2_1):
) )
class TestCompileTorchbind(TestCase):
def setUp(self):
init_torchbind_implementations()
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
class FakeTensorQueue:
def __init__(self, queue):
self.queue = queue
@classmethod
def __obj_unflatten__(cls, flattened_ctx):
return cls(**dict(flattened_ctx))
def push(self, x):
self.queue.append(x)
def pop(self):
return self.queue.pop(0)
def size(self):
return len(self.queue)
torch._dynamo.reset()
def tearDown(self):
torch._dynamo.reset()
def test_compile_script_object_input(self):
backend = EagerAndRecordGraphs()
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.check_tq_is_fake = True
def forward(self, tq, x):
tq.push(x.cos())
tq.push(x.sin())
x_sin = tq.pop() - tq.size()
return x_sin, tq
mod = Model()
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq2 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq3 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq4 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
x = torch.randn(2, 3)
ret = torch.compile(mod, backend=backend)(tq1, x)
eager_ret = mod(tq2, x)
_assertEqualSkipScriptObject(self, ret, eager_ret)
self.assertEqual(ret[1].size(), eager_ret[1].size())
self.assertEqual(ret[1].pop(), eager_ret[1].pop())
# Note that dynamo captured graph
# does not return L_tq_ as output. This is because it's able
# to detect that L_tq_ is an input therefore don't return
# it as graph output. Related logic is in dynamo/codegen.py
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, L_tq_ : torch.ScriptObject, L_x_ : torch.Tensor):
l_tq_ = L_tq_
l_x_ = L_x_
cos = l_x_.cos()
call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', cos); cos = None
sin = l_x_.sin(); l_x_ = None
call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', sin); sin = None
call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'pop')
call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None
x_sin = call_torchbind_2 - 1; call_torchbind_2 = None
return (x_sin,)""",
)
def test_compile_script_object_input_guards(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.check_tq_is_fake = True
def forward(self, tq, x):
tq.push(x.cos())
tq.push(x.sin())
x_sin = tq.pop() - tq.size()
return x_sin, tq
mod = Model()
cnt = torch._dynamo.testing.CompileCounter()
x = torch.randn(2, 3)
tq1 = _empty_tensor_queue()
torch.compile(mod, backend=cnt)(tq1, x)
self.assertEqual(cnt.frame_count, 1)
tq2 = _empty_tensor_queue()
for _ in range(10):
tq2.push(torch.randn(4, 5, requires_grad=False))
torch.compile(mod, backend=cnt)(tq2, x)
# Queue length change causes re-compile
self.assertEqual(cnt.frame_count, 2)
tq3 = _empty_tensor_queue()
tq3.push(torch.randn(2, 3, requires_grad=False))
torch.compile(mod, backend=cnt)(tq3, x)
# Tensor in queue changes shape causes re-compile
self.assertEqual(cnt.frame_count, 3)
tq4 = _empty_tensor_queue()
tq4.push(torch.randn(2, 3, requires_grad=False))
torch.compile(mod, backend=cnt)(tq4, x)
# No recompile
self.assertEqual(cnt.frame_count, 3)
tq5 = _empty_tensor_queue()
tq5.push(torch.randn(2, 3, requires_grad=True))
torch.compile(mod, backend=cnt)(tq5, x)
# Tensor in queue changes dispatch key causes re-compile
self.assertEqual(cnt.frame_count, 4)
tq6 = _empty_tensor_queue()
tq6.push(torch.randn(2, 3, requires_grad=True, dtype=torch.float64))
torch.compile(mod, backend=cnt)(tq6, x)
# Tensor in queue changes dtype causes re-compile
self.assertEqual(cnt.frame_count, 5)
def test_compile_script_object_input_automatic_dynamic_shape(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.check_tq_is_fake = True
def forward(self, tq, x):
tq.push(x.cos())
tq.push(x.sin())
x_sin = tq.pop() - tq.size()
return x_sin, tq
mod = Model()
cnt = torch._dynamo.testing.CompileCounter()
x = torch.randn(2, 3)
tq1 = _empty_tensor_queue()
tq1.push(torch.randn(2, 3, requires_grad=False))
torch.compile(mod, backend=cnt)(tq1, x)
self.assertEqual(cnt.frame_count, 1)
tq2 = _empty_tensor_queue()
# make first tensor's secon dim dynamic
tq2.push(torch.randn(2, 4, requires_grad=False))
torch.compile(mod, backend=cnt)(tq2, x)
self.assertEqual(cnt.frame_count, 2)
tq3 = _empty_tensor_queue()
tq3.push(torch.randn(2, 5, requires_grad=False))
# should have no-recompilation
torch.compile(mod, backend=cnt)(tq3, x)
self.assertEqual(cnt.frame_count, 2)
def test_compile_error_on_input_aliasing_contents(self):
backend = EagerAndRecordGraphs()
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.check_tq_is_fake = True
def forward(self, tq, x):
return x.sin(), tq.pop().cos()
x = torch.randn(2, 3)
mod = Model()
tq1 = _empty_tensor_queue()
tq1.push(x)
with self.assertRaisesRegex(RuntimeError, "is alising"):
torch.compile(mod, backend=backend)(tq1, x)
def test_compile_error_on_script_obj_setattr(self):
def setattr_f(tq):
tq.a = 1
return tq
with self.assertRaisesRegex(
RuntimeError, "call method __setattr__ on script object is not safe"
):
torch.compile(setattr_f, backend="eager")(_empty_tensor_queue())
def test_compile_error_on_script_obj_missing_attr(self):
def setattr_f(tq):
return tq._not_defined_attr
with self.assertRaisesRegex(
RuntimeError, "doesn't define method _not_defined_attr"
):
torch.compile(setattr_f, backend="eager")(_empty_tensor_queue())
def test_compile_body_aliasing_contents(self):
backend = EagerAndRecordGraphs()
def f(tq, x):
x1 = x.view(-1)
x2 = x.permute(1, 0)
tq.push(x1)
tq.push(x2)
return x1 - tq.size(), x2 + tq.size(), tq
x = torch.randn(2, 3)
_assertEqualScriptObject(
self,
f(_empty_tensor_queue(), x),
torch.compile(f, backend=backend)(_empty_tensor_queue(), x),
)
if not torch._dynamo.is_compiling():
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
l_x_ = L_x_
l_tq_ = L_tq_
x1 = l_x_.view(-1)
x2 = l_x_.permute(1, 0); l_x_ = None
call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x1)
call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x2)
call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'size')
sub = x1 - 2; x1 = None
call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None
add = x2 + 2; x2 = None
return (sub, add)""",
)
def test_compile_error_on_non_fakified_method(self):
backend = EagerAndRecordGraphs()
def f(tq, x):
x1 = x.view(-1)
x2 = x.permute(1, 0)
tq.push(x1)
tq.push(x2)
# though real tensor queue implemented a method clone_queue,
# The fakified version doesn't.
flat_obj = tq.clone_queue()
return flat_obj
x = torch.randn(2, 3)
with self.assertRaisesRegex(
RuntimeError, "FakeScriptObject doesn't define method"
):
torch.compile(f, backend=backend)(_empty_tensor_queue(), x)
def test_compile_obj_as_hop_input(self):
def f(tq, x):
def fn(tq, x):
tq.push(x)
return x.sin()
return wrap(fn, tq, x)
x = torch.randn(2, 3)
_assertEqualScriptObject(
self,
f(_empty_tensor_queue(), x),
torch.compile(f, backend="eager")(_empty_tensor_queue(), x),
)
def test_compile_obj_closure(self):
def f(x):
def inner_f(x):
tq.push(x.sin())
inner_f(x)
return tq.pop(), tq
opt_f = torch.compile(f, backend="eager")
tq = _empty_tensor_queue()
x = torch.randn(3, 2)
_assertEqualScriptObject(self, f(x), opt_f(x))
def test_compile_global_obj(self):
global _TENSOR_QUEUE_GLOBAL_TEST
_TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
def f(x):
_TENSOR_QUEUE_GLOBAL_TEST.push(x.sin())
return _TENSOR_QUEUE_GLOBAL_TEST.pop(), _TENSOR_QUEUE_GLOBAL_TEST
opt_f = torch.compile(f, backend="eager")
x = torch.randn(3, 2)
eager_ret = f(x)
opt_ret = opt_f(x)
_assertEqualScriptObject(self, eager_ret, opt_ret)
def test_compile_obj_graph_breaks(self):
cnt = torch._dynamo.testing.CompileCounter()
def f(tq, x):
tq.push(x.sin())
tq.push(x.sin())
torch._dynamo.graph_break()
tq.pop()
torch._dynamo.graph_break()
tq.push(x.cos() + tq.size())
torch._dynamo.graph_break()
tq.push(x.cos() - tq.size())
return x, tq.pop(), tq
opt_f = torch.compile(f, backend=cnt)
x = torch.randn(3, 2)
_assertEqualScriptObject(
self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x)
)
self.assertEqual(cnt.frame_count, 4)
def test_compile_obj_attributes(self):
backend = EagerAndRecordGraphs()
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.tq = _empty_tensor_queue()
def forward(self, x):
self.tq.push(x)
return self.tq.pop()
x = torch.randn(2, 3)
opt_f = torch.compile(Model(), backend=backend)
_assertEqualScriptObject(self, Model()(x), opt_f(x))
self.assertEqual(len(backend.graphs), 1)
# lifted as input. In the future, we would want to cosolidate this
# with non-strict behavior, where they're set as attributes.
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, L_self_tq : torch.ScriptObject, L_x_ : torch.Tensor):
l_self_tq = L_self_tq
l_x_ = L_x_
call_torchbind = torch.ops.higher_order.call_torchbind(l_self_tq, 'push', l_x_); l_x_ = None
call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_self_tq, 'pop'); l_self_tq = None
return (call_torchbind_1,)""",
)
def test_compile_obj_torchbind_op(self):
def f(tq, x):
torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1)
torch.ops._TorchScriptTesting.queue_pop(tq)
torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
return tq.pop(), tq.pop() + tq.size(), tq
opt_f = torch.compile(f, backend="eager")
x = torch.randn(2)
_assertEqualScriptObject(
self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x)
)
@skipIfTorchDynamo("torchbind not supported with dynamo yet") @skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestRegisterFakeClass(TestCase): class TestRegisterFakeClass(TestCase):
def setUp(self): def setUp(self):
@ -845,7 +1239,9 @@ class TestRegisterFakeClass(TestCase):
pass pass
def test_register_fake_class_no_from_real(self): def test_register_fake_class_no_from_real(self):
with self.assertRaisesRegex(RuntimeError, "define a classmethod from_real"): with self.assertRaisesRegex(
RuntimeError, "define a classmethod __obj_unflatten__"
):
@torch._library.register_fake_class("_TorchScriptTesting::_Foo") @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class InvalidFakeFoo: class InvalidFakeFoo:
@ -861,9 +1257,8 @@ class TestRegisterFakeClass(TestCase):
self.x = x self.x = x
self.y = y self.y = y
def from_real(self, foo_obj): def __obj_unflatten__(cls, flattend_foo): # noqa: B902
x, y = foo_obj.__getstate__() return cls(**dict(flattend_foo))
return FakeFoo(x, y)
def test_register_fake_class_valid(self): def test_register_fake_class_valid(self):
class FakeFoo: class FakeFoo:
@ -872,9 +1267,8 @@ class TestRegisterFakeClass(TestCase):
self.y = y self.y = y
@classmethod @classmethod
def from_real(cls, foo_obj): def __obj_unflatten__(cls, flattend_foo):
x, y = foo_obj.__getstate__() return cls(**dict(flattend_foo))
return cls(x, y)
torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo) torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)

View File

@ -573,9 +573,8 @@ class TestUnflatten(TestCase):
self.y = y self.y = y
@classmethod @classmethod
def from_real(cls, foo): def __obj_unflatten__(cls, flat_ctx):
(x, y), _ = foo.__getstate__() return cls(**dict(flat_ctx))
return cls(x, y)
def add_tensor(self, z): def add_tensor(self, z):
return (self.x + self.y) * z return (self.x + self.y) * z

View File

@ -17,34 +17,21 @@ from torch.testing import FileCheck
from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM80OrLater from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM80OrLater
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS, IS_WINDOWS,
run_tests, run_tests,
TEST_CUDA, TEST_CUDA,
TEST_WITH_ROCM, TEST_WITH_ROCM,
TestCase, TestCase,
) )
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.utils.hooks import RemovableHandle from torch.utils.hooks import RemovableHandle
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't support") @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't support")
class TestWithEffects(TestCase): class TestWithEffects(TestCase):
def setUp(self): def setUp(self):
if IS_MACOS: init_torchbind_implementations()
raise unittest.SkipTest("non-portable load_library call used in test")
elif IS_SANDCASTLE or IS_FBCODE:
torch.ops.load_library(
"//caffe2/test/cpp/jit:test_custom_class_registrations"
)
elif IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
else:
lib_file_path = find_library_location("libtorchbind_test.so")
torch.ops.load_library(str(lib_file_path))
def test_print(self): def test_print(self):
class M(torch.nn.Module): class M(torch.nn.Module):

View File

@ -171,6 +171,10 @@ class UserStopIteration(TorchDynamoException):
self.value = None self.value = None
class UnsafeScriptObjectError(TorchDynamoException):
pass
class UncapturedHigherOrderOpError(TorchDynamoException): class UncapturedHigherOrderOpError(TorchDynamoException):
pass pass

View File

@ -41,6 +41,7 @@ except ModuleNotFoundError:
import torch import torch
import torch.utils._device import torch.utils._device
from torch._dynamo.source import ( from torch._dynamo.source import (
is_from_flatten_script_object_source,
is_from_local_source, is_from_local_source,
is_from_optimizer_source, is_from_optimizer_source,
TensorProperty, TensorProperty,
@ -72,6 +73,7 @@ from .source import (
ChainedSource, ChainedSource,
ConstDictKeySource, ConstDictKeySource,
DefaultsSource, DefaultsSource,
FlattenScriptObjectSource,
FSDPNNModuleSource, FSDPNNModuleSource,
GetItemSource, GetItemSource,
GlobalSource, GlobalSource,
@ -84,6 +86,7 @@ from .source import (
NumpyTensorSource, NumpyTensorSource,
ODictGetItemSource, ODictGetItemSource,
OptimizerSource, OptimizerSource,
ScriptObjectQualifiedNameSource,
ShapeEnvSource, ShapeEnvSource,
TupleIteratorGetItemSource, TupleIteratorGetItemSource,
TypeSource, TypeSource,
@ -957,6 +960,22 @@ class GuardBuilder(GuardBuilderBase):
example_value=example_value, example_value=example_value,
guard_manager_enum=guard_manager_enum, guard_manager_enum=guard_manager_enum,
) )
elif istype(source, FlattenScriptObjectSource):
assert base_guard_manager # to make mypy happy
return base_guard_manager.lambda_manager(
python_lambda=lambda x: x.__obj_flatten__(),
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, ScriptObjectQualifiedNameSource):
assert base_guard_manager # to make mypy happy
return base_guard_manager.lambda_manager(
python_lambda=lambda x: x._type().qualified_name(),
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, TupleIteratorGetItemSource): elif istype(source, TupleIteratorGetItemSource):
assert base_guard_manager # to make mypy happy assert base_guard_manager # to make mypy happy
return base_guard_manager.tuple_iterator_getitem_manager( return base_guard_manager.tuple_iterator_getitem_manager(
@ -2602,6 +2621,14 @@ def make_dupe_guard(obj_source, dupe_source):
if dupe_source and dupe_source != obj_source: if dupe_source and dupe_source != obj_source:
ser_source_is_local = is_from_local_source(dupe_source) ser_source_is_local = is_from_local_source(dupe_source)
source_is_local = is_from_local_source(obj_source) source_is_local = is_from_local_source(obj_source)
if is_from_flatten_script_object_source(
dupe_source
) or is_from_flatten_script_object_source(obj_source):
raise exc.UnsafeScriptObjectError(
f"{obj_source.name()} is alising {dupe_source.name()}. This is not supported."
f" Please do a clone for corresponding input."
)
# Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently # Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
# reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here, # reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
# so maybe we should do this refactor before we land this... # so maybe we should do this refactor before we land this...

View File

@ -1434,7 +1434,11 @@ class OutputGraph:
self.remove_node(node) self.remove_node(node)
self.real_value_cache.pop(node, None) self.real_value_cache.pop(node, None)
used_symbols = set() used_symbols: Set[sympy.Symbol] = set()
def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]):
used_symbols |= free_symbols(fake)
recheck_placeholders = [] recheck_placeholders = []
for node in self.placeholders: for node in self.placeholders:
binds_symbol = placeholder_binds_symbol(node) is not None binds_symbol = placeholder_binds_symbol(node) is not None
@ -1452,10 +1456,22 @@ class OutputGraph:
arg = node.meta["grapharg"] arg = node.meta["grapharg"]
if isinstance(arg, BackwardStateGraphArg): if isinstance(arg, BackwardStateGraphArg):
continue continue
if isinstance(node.meta["grapharg"].example, torch.ScriptObject):
real_script_obj = node.meta["grapharg"].example
fake_script_obj = node.meta["grapharg"].example_strong_ref
flat_dict = dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined]
for attr in flat_dict.keys():
fake_attr_val = getattr(fake_script_obj.wrapped_obj, attr)
pytree.tree_map_only(
(torch.SymInt, torch.Tensor),
lambda t: update_used_symbols(used_symbols, t),
fake_attr_val,
)
continue
fake = ( fake = (
arg.fake_tensor if arg.fake_tensor is not None else arg.example arg.fake_tensor if arg.fake_tensor is not None else arg.example
) )
used_symbols |= free_symbols(fake) update_used_symbols(used_symbols, fake)
# After removing unused graphargs, prune unused binds_symbol # After removing unused graphargs, prune unused binds_symbol
for node in recheck_placeholders: for node in recheck_placeholders:

View File

@ -299,6 +299,36 @@ class ConvertIntSource(ChainedSource):
return f"cast_symbool_to_symint_guardless({self.base.name()})" return f"cast_symbool_to_symint_guardless({self.base.name()})"
@dataclasses.dataclass(frozen=True)
class FlattenScriptObjectSource(ChainedSource):
def __post_init__(self):
assert self.base is not None
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
def guard_source(self):
return self.base.guard_source()
def name(self):
return f"{self.base.name()}.__obj_flatten__()"
@dataclasses.dataclass(frozen=True)
class ScriptObjectQualifiedNameSource(ChainedSource):
def __post_init__(self):
assert self.base is not None
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
def guard_source(self):
return self.base.guard_source()
def name(self):
return f"{self.base.name()}._type().qualified_name()"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class DefaultsSource(ChainedSource): class DefaultsSource(ChainedSource):
idx_key: Union[int, str] idx_key: Union[int, str]
@ -563,6 +593,14 @@ def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
return True return True
def is_from_flatten_script_object_source(source: Source):
if isinstance(source, FlattenScriptObjectSource):
return True
elif isinstance(source, ChainedSource):
return is_from_flatten_script_object_source(source.base)
return False
def is_from_optimizer_source(source: Source): def is_from_optimizer_source(source: Source):
if isinstance(source, OptimizerSource): if isinstance(source, OptimizerSource):
return True return True

View File

@ -13,7 +13,7 @@ import operator
import re import re
import sys import sys
import types import types
from typing import List, NamedTuple, Optional, Union from typing import Any, List, NamedTuple, Optional, Union
from torch.utils._sympy.value_ranges import ValueRanges from torch.utils._sympy.value_ranges import ValueRanges
@ -26,6 +26,7 @@ import torch
from torch import SymInt from torch import SymInt
from torch._guards import GuardSource, TracingContext from torch._guards import GuardSource, TracingContext
from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator from torch._ops import HigherOrderOperator
from torch._streambase import _EventBase, _StreamBase from torch._streambase import _EventBase, _StreamBase
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
@ -154,6 +155,7 @@ from .misc import (
) )
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
from .optimizer import OptimizerVariable from .optimizer import OptimizerVariable
from .script_object import TorchScriptObjectVariable
from .sdpa import SDPAParamsVariable from .sdpa import SDPAParamsVariable
from .tensor import ( from .tensor import (
@ -904,13 +906,65 @@ class VariableBuilder:
user_cls_source=AttrSource(self.source, "__class__"), user_cls_source=AttrSource(self.source, "__class__"),
), ),
) )
elif TorchScriptObjectVariable.is_matching_cls(type(value)):
from ..source import (
FlattenScriptObjectSource,
ScriptObjectQualifiedNameSource,
)
# This exists to allow a smoother transition.
# The implications are:
# The script objects won't be tracked as proxies.
# Methods on these objects won't show up in the graph.
# The original script object might be mutated.
if not hasattr(value, "__obj_flatten__"):
return self.wrap_user_defined(value)
# Install the guards on the fully qualified name of the script object
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
value._type().qualified_name() # type: ignore[attr-defined]
)
)
# Install the guards on the content of the script object by setting the source
# to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
value.__obj_flatten__()
)
)
fake_script_obj = torch._library.fake_class_registry.to_fake_obj(
self.tx.output.fake_mode, value
)
proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(value),
source=self.source,
)
# setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default
# seting example to be real value because these example values will be used
# as example_inputs for user compiler.
proxy.node.meta["grapharg"] = GraphArg(
self.source, value, False, None, False, fake_script_obj
)
return TorchScriptObjectVariable.create(
proxy,
fake_script_obj,
source=self.source,
)
else: else:
self.install_guards(GuardBuilder.TYPE_MATCH) return self.wrap_user_defined(value)
result = UserDefinedObjectVariable(value, source=self.source)
if not SideEffects.cls_supports_mutation_side_effects(type(value)): def wrap_user_defined(self, value: Any):
# don't allow STORE_ATTR mutation with custom __setattr__ self.install_guards(GuardBuilder.TYPE_MATCH)
return result result = UserDefinedObjectVariable(value, source=self.source)
return self.tx.output.side_effects.track_object_existing(value, result) if not SideEffects.cls_supports_mutation_side_effects(type(value)):
# don't allow STORE_ATTR mutation with custom __setattr__
return result
return self.tx.output.side_effects.track_object_existing(value, result)
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
if config.specialize_int and type(value) is torch.Size: if config.specialize_int and type(value) is torch.Size:
@ -1782,6 +1836,12 @@ def wrap_fx_proxy_cls(
]: ]:
set_example_value(proxy.node, example_value) set_example_value(proxy.node, example_value)
return ConstantVariable.create(example_value, **options) return ConstantVariable.create(example_value, **options)
elif (
isinstance(example_value, (int, float, bool))
and proxy.node.target is call_torchbind
):
set_example_value(proxy.node, example_value)
return ConstantVariable.create(example_value, **options)
else: else:
unimplemented( unimplemented(
"torch.* op returned non-Tensor " "torch.* op returned non-Tensor "

View File

@ -549,6 +549,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
return StrictModeHigherOrderVariable(value, source, **kwargs) return StrictModeHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "associative_scan": elif value.__name__ == "associative_scan":
return AssociativeScanHigherOrderVariable(value, source, **kwargs) return AssociativeScanHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "call_torchbind":
return CallTorchbindHigherOrderVariable(value, source, **kwargs)
else: else:
unimplemented(f"HigherOrderOperator {value.__name__}") unimplemented(f"HigherOrderOperator {value.__name__}")
@ -769,6 +771,34 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
) )
class CallTorchbindHigherOrderVariable(TorchHigherOrderOperatorVariable):
def __init__(self, hop, source, script_obj_var, method_name):
super().__init__(hop, source)
self.script_obj_var = script_obj_var
self.method_name = method_name
def call_function(
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
from .builder import wrap_fx_proxy
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
args_proxy = [arg.as_proxy() for arg in args]
kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()}
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(
[self.script_obj_var.as_proxy(), self.method_name] + args_proxy
),
kwargs=kwargs_proxy,
),
)
class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable): class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
@raise_hard_error_if_graph_break( @raise_hard_error_if_graph_break(
reason="while_loop doesn't work unless it is captured completely with torch.compile." reason="while_loop doesn't work unless it is captured completely with torch.compile."

View File

@ -0,0 +1,80 @@
import functools
from typing import Dict
import torch
from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
from .base import VariableTracker
from .user_defined import UserDefinedObjectVariable
def _raise_hard_error_if_graph_break(reason):
def deco(fn):
@functools.wraps(fn)
def graph_break_as_hard_error(*args, **kwargs):
try:
return fn(*args, **kwargs)
except Unsupported as e:
raise UnsafeScriptObjectError(e.msg) from e
return graph_break_as_hard_error
return deco
class TorchScriptObjectVariable(UserDefinedObjectVariable):
_fake_script_object_cache: Dict[int, "TorchScriptObjectVariable"] = {}
@classmethod
def is_matching_cls(cls, user_cls: type):
return issubclass(user_cls, torch.ScriptObject)
@staticmethod
def create(proxy, value, **options):
return TorchScriptObjectVariable(proxy, value, **options)
def __init__(self, proxy, value, source, **kwargs):
super().__init__(value, **kwargs)
self.proxy = proxy
self.proxy.node.meta["example_value"] = value
self.source = source
def as_proxy(self):
return self.proxy
@_raise_hard_error_if_graph_break(
"Dynamo cannot safely trace script object due to graph break."
)
def var_getattr(self, tx, name: str) -> VariableTracker:
from torch._higher_order_ops.torchbind import call_torchbind
from ..source import AttrSource
from .higher_order_ops import TorchHigherOrderOperatorVariable
method = getattr(self.value, name, None)
if method is None:
unimplemented(
f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?"
)
if not callable(method):
unimplemented(
"Only method calls on TorchScript objects can be supported safely."
" Please use method calls instead of attribute access."
)
return TorchHigherOrderOperatorVariable.make(
call_torchbind,
source=AttrSource(self.source, name),
script_obj_var=self,
method_name=name,
)
# We only support method calls on script objects. Interpreting the bytecodes
# should go through var_getattr then call_function instead of call_method.
#
# However, it's possible for call_method to be used directly e.g. for __setattr__.
@_raise_hard_error_if_graph_break(
"Dynamo cannot safely trace script object due to graph break."
)
def call_method(self, tx, name, args, kwargs):
unimplemented(f"call method {name} on script object is not safe.")

View File

@ -205,38 +205,3 @@ class AbstractImplCtx:
result, min=min, max=max result, min=min, max=max
) )
return result return result
def to_fake_tensor(self, tensor: torch.Tensor):
"""
Creates a fake tensor from a concrete tensor. Note: this is not needed for register_fake.
This is useful for register_fake_class (which is necessary for torch.compile) for custom class.
Users need to implement a from_real method that takes a real custom object and creates a fake
custom object. Users can use this API to create fake tensors for the tensor states in the custom object.
Args:
tensor (torch.Tensor): A concrete tensor.
Example::
>>> import torch
>>> @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") # xdoctest: +SKIP
... class FakeTensorQueue:
... def __init__(self, q):
... self.queue = q
...
... @classmethod
... def from_real(cls, real_tq):
... ctx = torch.library.get_ctx()
... fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
... return cls(fake_queue)
...
... def push(self, x):
... self.queue.append(x)
...
... def pop(self):
... return self.queue.pop(0)
...
... def size(self):
... return len(self.queue)
"""
return self._fake_mode.from_tensor(tensor)

View File

@ -65,8 +65,35 @@ class FakeClassRegistry:
global_fake_class_registry = FakeClassRegistry() global_fake_class_registry = FakeClassRegistry()
# TODO: add this check at compile time for __obj_flatten__.
def _check_valid_flat_script_obj(flat_x):
if not isinstance(flat_x, tuple):
raise RuntimeError("Expect flat x to be a tuple.")
for tp in flat_x:
if not isinstance(tp, tuple):
raise RuntimeError("Expect flat x to be a tuple of tuples.")
if not len(tp) == 2 or not isinstance(tp[0], str):
raise RuntimeError(
"Expect element of flat x to be a tuple of two elements with first element being a string"
)
def to_fake_obj(fake_mode, x: torch.ScriptObject) -> FakeScriptObject: def to_fake_obj(fake_mode, x: torch.ScriptObject) -> FakeScriptObject:
fake_x = _fake_obj_from_real(fake_mode, x) import torch.utils._pytree as pytree
flat_x = x.__obj_flatten__() # type: ignore[attr-defined]
_check_valid_flat_script_obj(flat_x)
fake_flattened = pytree.tree_map_only(
torch.Tensor,
lambda t: fake_mode.from_tensor(t),
flat_x,
)
fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)
def _call_torchbind(method_name): def _call_torchbind(method_name):
from torch._higher_order_ops.torchbind import call_torchbind from torch._higher_order_ops.torchbind import call_torchbind
@ -134,14 +161,13 @@ def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal]
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
class FakeTensorQueue: class FakeTensorQueue:
def __init__(self, q): def __init__(self, queue):
self.queue = q self.queue = queue
@classmethod @classmethod
def from_real(cls, real_tq): def __obj_unflatten__(cls, flattened_ctx):
ctx = torch.library.get_ctx() ctx = {flattened_ctx[0]: flattened_ctx[1]}
fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.clone_queue()] return cls(**ctx)
return cls(fake_queue)
def push(self, x): def push(self, x):
self.queue.append(x) self.queue.append(x)
@ -162,7 +188,9 @@ def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal]
from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None) from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
if not from_method: if not from_method:
raise RuntimeError(f"{fake_class} doesn't define a classmethod from_real.") raise RuntimeError(
f"{fake_class} doesn't define a classmethod {_CONVERT_FROM_REAL_NAME}."
)
if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod): if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod):
raise RuntimeError( raise RuntimeError(
@ -221,7 +249,7 @@ def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any:
return fake_class return fake_class
_CONVERT_FROM_REAL_NAME = "from_real" _CONVERT_FROM_REAL_NAME = "__obj_unflatten__"
def _fake_obj_from_real(fake_mode, x) -> Any: def _fake_obj_from_real(fake_mode, x) -> Any:

View File

@ -58,6 +58,7 @@ hop_that_doesnt_have_opinfo_test_allowlist = [
"with_effects", "with_effects",
"strict_mode", "strict_mode",
"_export_tracepoint", "_export_tracepoint",
"call_torchbind",
] ]
torch.library.define( torch.library.define(

View File

@ -1,22 +1,35 @@
import contextlib import contextlib
from typing import Optional
import torch import torch
_TORCHBIND_IMPLS_INITIALIZED = False _TORCHBIND_IMPLS_INITIALIZED = False
_TENSOR_QUEUE_GLOBAL_TEST: Optional[torch.ScriptObject] = None
def init_torchbind_implementations(): def init_torchbind_implementations():
global _TORCHBIND_IMPLS_INITIALIZED global _TORCHBIND_IMPLS_INITIALIZED
global _TENSOR_QUEUE_GLOBAL_TEST
if _TORCHBIND_IMPLS_INITIALIZED: if _TORCHBIND_IMPLS_INITIALIZED:
return return
load_torchbind_test_lib() load_torchbind_test_lib()
register_fake_operators() register_fake_operators()
register_fake_classes() register_fake_classes()
_TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
_TORCHBIND_IMPLS_INITIALIZED = True _TORCHBIND_IMPLS_INITIALIZED = True
def _empty_tensor_queue() -> torch.ScriptObject:
return torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
# put these under a function because the corresponding library might not be loaded yet. # put these under a function because the corresponding library might not be loaded yet.
def register_fake_operators(): def register_fake_operators():
@torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta") @torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta")
@ -67,25 +80,23 @@ def register_fake_classes():
self.y = y self.y = y
@classmethod @classmethod
def from_real(cls, foo): def __obj_unflatten__(cls, flattend_foo):
(x, y), _ = foo.__getstate__() return cls(**dict(flattend_foo))
return cls(x, y)
def add_tensor(self, z): def add_tensor(self, z):
return (self.x + self.y) * z return (self.x + self.y) * z
@torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor") @torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor")
class FakeContainsTensor: class FakeContainsTensor:
def __init__(self, x: torch.Tensor): def __init__(self, t: torch.Tensor):
self.x = x self.t = t
@classmethod @classmethod
def from_real(cls, foo): def __obj_unflatten__(cls, flattend_foo):
ctx = torch.library.get_ctx() return cls(**dict(flattend_foo))
return cls(ctx.to_fake_tensor(foo.get()))
def get(self): def get(self):
return self.x return self.t
def load_torchbind_test_lib(): def load_torchbind_test_lib():