mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[dynamo] support torchbind object input (#124978)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124978 Approved by: https://github.com/jansel
This commit is contained in:
parent
c165a8e71d
commit
461ffaaaf3
|
|
@ -59,6 +59,10 @@ struct Foo : torch::CustomClassHolder {
|
||||||
bool eq(c10::intrusive_ptr<Foo> other) {
|
bool eq(c10::intrusive_ptr<Foo> other) {
|
||||||
return this->x == other->x && this->y == other->y;
|
return this->x == other->x && this->y == other->y;
|
||||||
}
|
}
|
||||||
|
std::tuple<std::tuple<std::string, int64_t>, std::tuple<std::string, int64_t>>
|
||||||
|
__obj_flatten__() {
|
||||||
|
return std::tuple(std::tuple("x", this->x), std::tuple("y", this->y));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct _StaticMethod : torch::CustomClassHolder {
|
struct _StaticMethod : torch::CustomClassHolder {
|
||||||
|
|
@ -199,6 +203,10 @@ struct TensorQueue : torch::CustomClassHolder {
|
||||||
return raw_queue;
|
return raw_queue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::tuple<std::tuple<std::string, std::vector<at::Tensor>>> __obj_flatten__() {
|
||||||
|
return std::tuple(std::tuple("queue", this->get_raw_queue()));
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::deque<at::Tensor> queue_;
|
std::deque<at::Tensor> queue_;
|
||||||
std::mutex mutex_;
|
std::mutex mutex_;
|
||||||
|
|
@ -370,6 +378,10 @@ struct ContainsTensor : public torch::CustomClassHolder {
|
||||||
return t_;
|
return t_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
|
||||||
|
return std::tuple(std::tuple("t", this->t_));
|
||||||
|
}
|
||||||
|
|
||||||
at::Tensor t_;
|
at::Tensor t_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -417,6 +429,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
|
||||||
.def("add_tensor", &Foo::add_tensor)
|
.def("add_tensor", &Foo::add_tensor)
|
||||||
.def("__eq__", &Foo::eq)
|
.def("__eq__", &Foo::eq)
|
||||||
.def("combine", &Foo::combine)
|
.def("combine", &Foo::combine)
|
||||||
|
.def("__obj_flatten__", &Foo::__obj_flatten__)
|
||||||
.def_pickle(
|
.def_pickle(
|
||||||
[](c10::intrusive_ptr<Foo> self) { // __getstate__
|
[](c10::intrusive_ptr<Foo> self) { // __getstate__
|
||||||
return std::vector<int64_t>{self->x, self->y};
|
return std::vector<int64_t>{self->x, self->y};
|
||||||
|
|
@ -424,6 +437,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
|
||||||
[](std::vector<int64_t> state) { // __setstate__
|
[](std::vector<int64_t> state) { // __setstate__
|
||||||
return c10::make_intrusive<Foo>(state[0], state[1]);
|
return c10::make_intrusive<Foo>(state[0], state[1]);
|
||||||
});
|
});
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
|
"takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
|
||||||
m.def(
|
m.def(
|
||||||
|
|
@ -551,6 +565,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
|
||||||
m.class_<ContainsTensor>("_ContainsTensor")
|
m.class_<ContainsTensor>("_ContainsTensor")
|
||||||
.def(torch::init<at::Tensor>())
|
.def(torch::init<at::Tensor>())
|
||||||
.def("get", &ContainsTensor::get)
|
.def("get", &ContainsTensor::get)
|
||||||
|
.def("__obj_flatten__", &ContainsTensor::__obj_flatten__)
|
||||||
.def_pickle(
|
.def_pickle(
|
||||||
// __getstate__
|
// __getstate__
|
||||||
[](const c10::intrusive_ptr<ContainsTensor>& self) -> at::Tensor {
|
[](const c10::intrusive_ptr<ContainsTensor>& self) -> at::Tensor {
|
||||||
|
|
@ -568,6 +583,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
|
||||||
.def("size", &TensorQueue::size)
|
.def("size", &TensorQueue::size)
|
||||||
.def("clone_queue", &TensorQueue::clone_queue)
|
.def("clone_queue", &TensorQueue::clone_queue)
|
||||||
.def("get_raw_queue", &TensorQueue::get_raw_queue)
|
.def("get_raw_queue", &TensorQueue::get_raw_queue)
|
||||||
|
.def("__obj_flatten__", &TensorQueue::__obj_flatten__)
|
||||||
.def_pickle(
|
.def_pickle(
|
||||||
// __getstate__
|
// __getstate__
|
||||||
[](const c10::intrusive_ptr<TensorQueue>& self)
|
[](const c10::intrusive_ptr<TensorQueue>& self)
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,11 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils._pytree as pytree
|
import torch.utils._pytree as pytree
|
||||||
|
from torch._dynamo.testing import EagerAndRecordGraphs
|
||||||
from torch._functorch.aot_autograd import aot_export_module
|
from torch._functorch.aot_autograd import aot_export_module
|
||||||
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
|
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
|
||||||
|
|
||||||
|
from torch._higher_order_ops.wrap import wrap
|
||||||
from torch._library.fake_class_registry import FakeScriptObject
|
from torch._library.fake_class_registry import FakeScriptObject
|
||||||
from torch.export import export
|
from torch.export import export
|
||||||
from torch.export._trace import _export
|
from torch.export._trace import _export
|
||||||
|
|
@ -16,7 +19,39 @@ from torch.testing._internal.common_utils import (
|
||||||
skipIfTorchDynamo,
|
skipIfTorchDynamo,
|
||||||
TestCase,
|
TestCase,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
|
from torch.testing._internal.torchbind_impls import (
|
||||||
|
_empty_tensor_queue,
|
||||||
|
init_torchbind_implementations,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _assertEqualSkipScriptObject(test_case, exp, actual):
|
||||||
|
flat_exp = pytree.tree_leaves(exp)
|
||||||
|
flat_actual = pytree.tree_leaves(actual)
|
||||||
|
test_case.assertEqual(len(flat_exp), len(flat_actual))
|
||||||
|
for a, b in zip(flat_exp, flat_actual):
|
||||||
|
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
|
||||||
|
continue
|
||||||
|
test_case.assertEqual(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_script_obj_equal(test_case, a: torch.ScriptObject, b: torch.ScriptObject):
|
||||||
|
return test_case.assertEqual(
|
||||||
|
a._type().qualified_name(), b._type().qualified_name()
|
||||||
|
) and test_case.assertEqual(a.__obj_flatten__(), b.__obj_flatten__())
|
||||||
|
|
||||||
|
|
||||||
|
def _assertEqualScriptObject(
|
||||||
|
test_case, exp, actual, check_obj_eq=_check_script_obj_equal
|
||||||
|
):
|
||||||
|
flat_exp = pytree.tree_leaves(exp)
|
||||||
|
flat_actual = pytree.tree_leaves(actual)
|
||||||
|
test_case.assertEqual(len(flat_exp), len(flat_actual))
|
||||||
|
for a, b in zip(flat_exp, flat_actual):
|
||||||
|
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
|
||||||
|
check_obj_eq(test_case, a, b)
|
||||||
|
else:
|
||||||
|
test_case.assertEqual(a, b)
|
||||||
|
|
||||||
|
|
||||||
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
|
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
|
||||||
|
|
@ -37,9 +72,8 @@ class TestExportTorchbind(TestCase):
|
||||||
self.y = y
|
self.y = y
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_real(cls, foo):
|
def __obj_unflatten__(cls, flattend_foo):
|
||||||
(x, y), _ = foo.__getstate__()
|
return cls(**dict(flattend_foo))
|
||||||
return cls(x, y)
|
|
||||||
|
|
||||||
def add_tensor(self, z):
|
def add_tensor(self, z):
|
||||||
test.foo_add_tensor_counter += 1
|
test.foo_add_tensor_counter += 1
|
||||||
|
|
@ -47,14 +81,12 @@ class TestExportTorchbind(TestCase):
|
||||||
|
|
||||||
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
|
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
|
||||||
class FakeTensorQueue:
|
class FakeTensorQueue:
|
||||||
def __init__(self, q):
|
def __init__(self, queue):
|
||||||
self.queue = q
|
self.queue = queue
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_real(cls, real_tq):
|
def __obj_unflatten__(cls, flattened_ctx):
|
||||||
ctx = torch.library.get_ctx()
|
return cls(**dict(flattened_ctx))
|
||||||
fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
|
|
||||||
return cls(fake_queue)
|
|
||||||
|
|
||||||
def push(self, x):
|
def push(self, x):
|
||||||
test.tq_push_counter += 1
|
test.tq_push_counter += 1
|
||||||
|
|
@ -89,15 +121,6 @@ class TestExportTorchbind(TestCase):
|
||||||
"_TorchScriptTesting::_TensorQueue"
|
"_TorchScriptTesting::_TensorQueue"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _assertEqualSkipScriptObject(self, exp, actual):
|
|
||||||
flat_exp = pytree.tree_leaves(exp)
|
|
||||||
flat_actual = pytree.tree_leaves(actual)
|
|
||||||
self.assertEqual(len(flat_exp), len(flat_actual))
|
|
||||||
for a, b in zip(flat_exp, flat_actual):
|
|
||||||
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
|
|
||||||
continue
|
|
||||||
self.assertEqual(a, b)
|
|
||||||
|
|
||||||
def _test_export_same_as_eager(
|
def _test_export_same_as_eager(
|
||||||
self, f, args, kwargs=None, strict=True, pre_dispatch=False
|
self, f, args, kwargs=None, strict=True, pre_dispatch=False
|
||||||
):
|
):
|
||||||
|
|
@ -532,7 +555,7 @@ def forward(self, arg0_1, arg1_1):
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
mod.check_tq_is_fake = False
|
mod.check_tq_is_fake = False
|
||||||
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
|
_assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x))
|
||||||
|
|
||||||
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
|
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
|
||||||
def test_make_fx_tensor_queue_methods_fakify_internal_states(
|
def test_make_fx_tensor_queue_methods_fakify_internal_states(
|
||||||
|
|
@ -592,7 +615,7 @@ def forward(self, arg0_1, arg1_1):
|
||||||
)
|
)
|
||||||
# turn off tq type checking in eager execution
|
# turn off tq type checking in eager execution
|
||||||
mod.check_tq_is_fake = False
|
mod.check_tq_is_fake = False
|
||||||
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
|
_assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x))
|
||||||
self.assertEqual(tq.size(), 0)
|
self.assertEqual(tq.size(), 0)
|
||||||
self.assertEqual(tq1.size(), 0)
|
self.assertEqual(tq1.size(), 0)
|
||||||
|
|
||||||
|
|
@ -769,7 +792,7 @@ def forward(self, arg0_1, arg1_1):
|
||||||
add = torch.ops.aten.add.Tensor(queue_pop_1, 0); queue_pop_1 = None
|
add = torch.ops.aten.add.Tensor(queue_pop_1, 0); queue_pop_1 = None
|
||||||
return (sub, add, arg0_1)""",
|
return (sub, add, arg0_1)""",
|
||||||
)
|
)
|
||||||
self._assertEqualSkipScriptObject(gm(tq1, x), mod(tq2, x))
|
_assertEqualSkipScriptObject(self, gm(tq1, x), mod(tq2, x))
|
||||||
|
|
||||||
def test_aot_export_tensor_queue_operators(self):
|
def test_aot_export_tensor_queue_operators(self):
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
|
|
@ -829,6 +852,377 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompileTorchbind(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
init_torchbind_implementations()
|
||||||
|
|
||||||
|
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
|
||||||
|
class FakeTensorQueue:
|
||||||
|
def __init__(self, queue):
|
||||||
|
self.queue = queue
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __obj_unflatten__(cls, flattened_ctx):
|
||||||
|
return cls(**dict(flattened_ctx))
|
||||||
|
|
||||||
|
def push(self, x):
|
||||||
|
self.queue.append(x)
|
||||||
|
|
||||||
|
def pop(self):
|
||||||
|
return self.queue.pop(0)
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
return len(self.queue)
|
||||||
|
|
||||||
|
torch._dynamo.reset()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
torch._dynamo.reset()
|
||||||
|
|
||||||
|
def test_compile_script_object_input(self):
|
||||||
|
backend = EagerAndRecordGraphs()
|
||||||
|
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.check_tq_is_fake = True
|
||||||
|
|
||||||
|
def forward(self, tq, x):
|
||||||
|
tq.push(x.cos())
|
||||||
|
tq.push(x.sin())
|
||||||
|
x_sin = tq.pop() - tq.size()
|
||||||
|
return x_sin, tq
|
||||||
|
|
||||||
|
mod = Model()
|
||||||
|
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
|
||||||
|
torch.empty(
|
||||||
|
0,
|
||||||
|
).fill_(-1)
|
||||||
|
)
|
||||||
|
tq2 = torch.classes._TorchScriptTesting._TensorQueue(
|
||||||
|
torch.empty(
|
||||||
|
0,
|
||||||
|
).fill_(-1)
|
||||||
|
)
|
||||||
|
tq3 = torch.classes._TorchScriptTesting._TensorQueue(
|
||||||
|
torch.empty(
|
||||||
|
0,
|
||||||
|
).fill_(-1)
|
||||||
|
)
|
||||||
|
tq4 = torch.classes._TorchScriptTesting._TensorQueue(
|
||||||
|
torch.empty(
|
||||||
|
0,
|
||||||
|
).fill_(-1)
|
||||||
|
)
|
||||||
|
x = torch.randn(2, 3)
|
||||||
|
ret = torch.compile(mod, backend=backend)(tq1, x)
|
||||||
|
eager_ret = mod(tq2, x)
|
||||||
|
_assertEqualSkipScriptObject(self, ret, eager_ret)
|
||||||
|
self.assertEqual(ret[1].size(), eager_ret[1].size())
|
||||||
|
self.assertEqual(ret[1].pop(), eager_ret[1].pop())
|
||||||
|
# Note that dynamo captured graph
|
||||||
|
# does not return L_tq_ as output. This is because it's able
|
||||||
|
# to detect that L_tq_ is an input therefore don't return
|
||||||
|
# it as graph output. Related logic is in dynamo/codegen.py
|
||||||
|
self.assertExpectedInline(
|
||||||
|
backend.graphs[0].code.strip(),
|
||||||
|
"""\
|
||||||
|
def forward(self, L_tq_ : torch.ScriptObject, L_x_ : torch.Tensor):
|
||||||
|
l_tq_ = L_tq_
|
||||||
|
l_x_ = L_x_
|
||||||
|
cos = l_x_.cos()
|
||||||
|
call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', cos); cos = None
|
||||||
|
sin = l_x_.sin(); l_x_ = None
|
||||||
|
call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', sin); sin = None
|
||||||
|
call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'pop')
|
||||||
|
call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None
|
||||||
|
x_sin = call_torchbind_2 - 1; call_torchbind_2 = None
|
||||||
|
return (x_sin,)""",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_compile_script_object_input_guards(self):
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.check_tq_is_fake = True
|
||||||
|
|
||||||
|
def forward(self, tq, x):
|
||||||
|
tq.push(x.cos())
|
||||||
|
tq.push(x.sin())
|
||||||
|
x_sin = tq.pop() - tq.size()
|
||||||
|
return x_sin, tq
|
||||||
|
|
||||||
|
mod = Model()
|
||||||
|
cnt = torch._dynamo.testing.CompileCounter()
|
||||||
|
x = torch.randn(2, 3)
|
||||||
|
|
||||||
|
tq1 = _empty_tensor_queue()
|
||||||
|
torch.compile(mod, backend=cnt)(tq1, x)
|
||||||
|
self.assertEqual(cnt.frame_count, 1)
|
||||||
|
|
||||||
|
tq2 = _empty_tensor_queue()
|
||||||
|
for _ in range(10):
|
||||||
|
tq2.push(torch.randn(4, 5, requires_grad=False))
|
||||||
|
torch.compile(mod, backend=cnt)(tq2, x)
|
||||||
|
# Queue length change causes re-compile
|
||||||
|
self.assertEqual(cnt.frame_count, 2)
|
||||||
|
|
||||||
|
tq3 = _empty_tensor_queue()
|
||||||
|
tq3.push(torch.randn(2, 3, requires_grad=False))
|
||||||
|
torch.compile(mod, backend=cnt)(tq3, x)
|
||||||
|
# Tensor in queue changes shape causes re-compile
|
||||||
|
self.assertEqual(cnt.frame_count, 3)
|
||||||
|
|
||||||
|
tq4 = _empty_tensor_queue()
|
||||||
|
tq4.push(torch.randn(2, 3, requires_grad=False))
|
||||||
|
torch.compile(mod, backend=cnt)(tq4, x)
|
||||||
|
# No recompile
|
||||||
|
self.assertEqual(cnt.frame_count, 3)
|
||||||
|
|
||||||
|
tq5 = _empty_tensor_queue()
|
||||||
|
tq5.push(torch.randn(2, 3, requires_grad=True))
|
||||||
|
torch.compile(mod, backend=cnt)(tq5, x)
|
||||||
|
# Tensor in queue changes dispatch key causes re-compile
|
||||||
|
self.assertEqual(cnt.frame_count, 4)
|
||||||
|
|
||||||
|
tq6 = _empty_tensor_queue()
|
||||||
|
tq6.push(torch.randn(2, 3, requires_grad=True, dtype=torch.float64))
|
||||||
|
torch.compile(mod, backend=cnt)(tq6, x)
|
||||||
|
# Tensor in queue changes dtype causes re-compile
|
||||||
|
self.assertEqual(cnt.frame_count, 5)
|
||||||
|
|
||||||
|
def test_compile_script_object_input_automatic_dynamic_shape(self):
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.check_tq_is_fake = True
|
||||||
|
|
||||||
|
def forward(self, tq, x):
|
||||||
|
tq.push(x.cos())
|
||||||
|
tq.push(x.sin())
|
||||||
|
x_sin = tq.pop() - tq.size()
|
||||||
|
return x_sin, tq
|
||||||
|
|
||||||
|
mod = Model()
|
||||||
|
cnt = torch._dynamo.testing.CompileCounter()
|
||||||
|
x = torch.randn(2, 3)
|
||||||
|
|
||||||
|
tq1 = _empty_tensor_queue()
|
||||||
|
tq1.push(torch.randn(2, 3, requires_grad=False))
|
||||||
|
torch.compile(mod, backend=cnt)(tq1, x)
|
||||||
|
self.assertEqual(cnt.frame_count, 1)
|
||||||
|
|
||||||
|
tq2 = _empty_tensor_queue()
|
||||||
|
# make first tensor's secon dim dynamic
|
||||||
|
tq2.push(torch.randn(2, 4, requires_grad=False))
|
||||||
|
torch.compile(mod, backend=cnt)(tq2, x)
|
||||||
|
self.assertEqual(cnt.frame_count, 2)
|
||||||
|
|
||||||
|
tq3 = _empty_tensor_queue()
|
||||||
|
tq3.push(torch.randn(2, 5, requires_grad=False))
|
||||||
|
# should have no-recompilation
|
||||||
|
torch.compile(mod, backend=cnt)(tq3, x)
|
||||||
|
self.assertEqual(cnt.frame_count, 2)
|
||||||
|
|
||||||
|
def test_compile_error_on_input_aliasing_contents(self):
|
||||||
|
backend = EagerAndRecordGraphs()
|
||||||
|
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.check_tq_is_fake = True
|
||||||
|
|
||||||
|
def forward(self, tq, x):
|
||||||
|
return x.sin(), tq.pop().cos()
|
||||||
|
|
||||||
|
x = torch.randn(2, 3)
|
||||||
|
mod = Model()
|
||||||
|
|
||||||
|
tq1 = _empty_tensor_queue()
|
||||||
|
tq1.push(x)
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "is alising"):
|
||||||
|
torch.compile(mod, backend=backend)(tq1, x)
|
||||||
|
|
||||||
|
def test_compile_error_on_script_obj_setattr(self):
|
||||||
|
def setattr_f(tq):
|
||||||
|
tq.a = 1
|
||||||
|
return tq
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, "call method __setattr__ on script object is not safe"
|
||||||
|
):
|
||||||
|
torch.compile(setattr_f, backend="eager")(_empty_tensor_queue())
|
||||||
|
|
||||||
|
def test_compile_error_on_script_obj_missing_attr(self):
|
||||||
|
def setattr_f(tq):
|
||||||
|
return tq._not_defined_attr
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, "doesn't define method _not_defined_attr"
|
||||||
|
):
|
||||||
|
torch.compile(setattr_f, backend="eager")(_empty_tensor_queue())
|
||||||
|
|
||||||
|
def test_compile_body_aliasing_contents(self):
|
||||||
|
backend = EagerAndRecordGraphs()
|
||||||
|
|
||||||
|
def f(tq, x):
|
||||||
|
x1 = x.view(-1)
|
||||||
|
x2 = x.permute(1, 0)
|
||||||
|
tq.push(x1)
|
||||||
|
tq.push(x2)
|
||||||
|
return x1 - tq.size(), x2 + tq.size(), tq
|
||||||
|
|
||||||
|
x = torch.randn(2, 3)
|
||||||
|
_assertEqualScriptObject(
|
||||||
|
self,
|
||||||
|
f(_empty_tensor_queue(), x),
|
||||||
|
torch.compile(f, backend=backend)(_empty_tensor_queue(), x),
|
||||||
|
)
|
||||||
|
if not torch._dynamo.is_compiling():
|
||||||
|
self.assertExpectedInline(
|
||||||
|
backend.graphs[0].code.strip(),
|
||||||
|
"""\
|
||||||
|
def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
|
||||||
|
l_x_ = L_x_
|
||||||
|
l_tq_ = L_tq_
|
||||||
|
x1 = l_x_.view(-1)
|
||||||
|
x2 = l_x_.permute(1, 0); l_x_ = None
|
||||||
|
call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x1)
|
||||||
|
call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x2)
|
||||||
|
call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'size')
|
||||||
|
sub = x1 - 2; x1 = None
|
||||||
|
call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None
|
||||||
|
add = x2 + 2; x2 = None
|
||||||
|
return (sub, add)""",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_compile_error_on_non_fakified_method(self):
|
||||||
|
backend = EagerAndRecordGraphs()
|
||||||
|
|
||||||
|
def f(tq, x):
|
||||||
|
x1 = x.view(-1)
|
||||||
|
x2 = x.permute(1, 0)
|
||||||
|
tq.push(x1)
|
||||||
|
tq.push(x2)
|
||||||
|
# though real tensor queue implemented a method clone_queue,
|
||||||
|
# The fakified version doesn't.
|
||||||
|
flat_obj = tq.clone_queue()
|
||||||
|
return flat_obj
|
||||||
|
|
||||||
|
x = torch.randn(2, 3)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, "FakeScriptObject doesn't define method"
|
||||||
|
):
|
||||||
|
torch.compile(f, backend=backend)(_empty_tensor_queue(), x)
|
||||||
|
|
||||||
|
def test_compile_obj_as_hop_input(self):
|
||||||
|
def f(tq, x):
|
||||||
|
def fn(tq, x):
|
||||||
|
tq.push(x)
|
||||||
|
return x.sin()
|
||||||
|
|
||||||
|
return wrap(fn, tq, x)
|
||||||
|
|
||||||
|
x = torch.randn(2, 3)
|
||||||
|
_assertEqualScriptObject(
|
||||||
|
self,
|
||||||
|
f(_empty_tensor_queue(), x),
|
||||||
|
torch.compile(f, backend="eager")(_empty_tensor_queue(), x),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_compile_obj_closure(self):
|
||||||
|
def f(x):
|
||||||
|
def inner_f(x):
|
||||||
|
tq.push(x.sin())
|
||||||
|
|
||||||
|
inner_f(x)
|
||||||
|
return tq.pop(), tq
|
||||||
|
|
||||||
|
opt_f = torch.compile(f, backend="eager")
|
||||||
|
|
||||||
|
tq = _empty_tensor_queue()
|
||||||
|
x = torch.randn(3, 2)
|
||||||
|
_assertEqualScriptObject(self, f(x), opt_f(x))
|
||||||
|
|
||||||
|
def test_compile_global_obj(self):
|
||||||
|
global _TENSOR_QUEUE_GLOBAL_TEST
|
||||||
|
_TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
_TENSOR_QUEUE_GLOBAL_TEST.push(x.sin())
|
||||||
|
return _TENSOR_QUEUE_GLOBAL_TEST.pop(), _TENSOR_QUEUE_GLOBAL_TEST
|
||||||
|
|
||||||
|
opt_f = torch.compile(f, backend="eager")
|
||||||
|
x = torch.randn(3, 2)
|
||||||
|
eager_ret = f(x)
|
||||||
|
opt_ret = opt_f(x)
|
||||||
|
_assertEqualScriptObject(self, eager_ret, opt_ret)
|
||||||
|
|
||||||
|
def test_compile_obj_graph_breaks(self):
|
||||||
|
cnt = torch._dynamo.testing.CompileCounter()
|
||||||
|
|
||||||
|
def f(tq, x):
|
||||||
|
tq.push(x.sin())
|
||||||
|
tq.push(x.sin())
|
||||||
|
torch._dynamo.graph_break()
|
||||||
|
tq.pop()
|
||||||
|
torch._dynamo.graph_break()
|
||||||
|
tq.push(x.cos() + tq.size())
|
||||||
|
torch._dynamo.graph_break()
|
||||||
|
tq.push(x.cos() - tq.size())
|
||||||
|
return x, tq.pop(), tq
|
||||||
|
|
||||||
|
opt_f = torch.compile(f, backend=cnt)
|
||||||
|
x = torch.randn(3, 2)
|
||||||
|
_assertEqualScriptObject(
|
||||||
|
self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x)
|
||||||
|
)
|
||||||
|
self.assertEqual(cnt.frame_count, 4)
|
||||||
|
|
||||||
|
def test_compile_obj_attributes(self):
|
||||||
|
backend = EagerAndRecordGraphs()
|
||||||
|
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.tq = _empty_tensor_queue()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
self.tq.push(x)
|
||||||
|
return self.tq.pop()
|
||||||
|
|
||||||
|
x = torch.randn(2, 3)
|
||||||
|
opt_f = torch.compile(Model(), backend=backend)
|
||||||
|
_assertEqualScriptObject(self, Model()(x), opt_f(x))
|
||||||
|
self.assertEqual(len(backend.graphs), 1)
|
||||||
|
# lifted as input. In the future, we would want to cosolidate this
|
||||||
|
# with non-strict behavior, where they're set as attributes.
|
||||||
|
self.assertExpectedInline(
|
||||||
|
backend.graphs[0].code.strip(),
|
||||||
|
"""\
|
||||||
|
def forward(self, L_self_tq : torch.ScriptObject, L_x_ : torch.Tensor):
|
||||||
|
l_self_tq = L_self_tq
|
||||||
|
l_x_ = L_x_
|
||||||
|
call_torchbind = torch.ops.higher_order.call_torchbind(l_self_tq, 'push', l_x_); l_x_ = None
|
||||||
|
call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_self_tq, 'pop'); l_self_tq = None
|
||||||
|
return (call_torchbind_1,)""",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_compile_obj_torchbind_op(self):
|
||||||
|
def f(tq, x):
|
||||||
|
torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
|
||||||
|
torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1)
|
||||||
|
torch.ops._TorchScriptTesting.queue_pop(tq)
|
||||||
|
torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
|
||||||
|
return tq.pop(), tq.pop() + tq.size(), tq
|
||||||
|
|
||||||
|
opt_f = torch.compile(f, backend="eager")
|
||||||
|
x = torch.randn(2)
|
||||||
|
_assertEqualScriptObject(
|
||||||
|
self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
|
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
|
||||||
class TestRegisterFakeClass(TestCase):
|
class TestRegisterFakeClass(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|
@ -845,7 +1239,9 @@ class TestRegisterFakeClass(TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_register_fake_class_no_from_real(self):
|
def test_register_fake_class_no_from_real(self):
|
||||||
with self.assertRaisesRegex(RuntimeError, "define a classmethod from_real"):
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError, "define a classmethod __obj_unflatten__"
|
||||||
|
):
|
||||||
|
|
||||||
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
|
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
|
||||||
class InvalidFakeFoo:
|
class InvalidFakeFoo:
|
||||||
|
|
@ -861,9 +1257,8 @@ class TestRegisterFakeClass(TestCase):
|
||||||
self.x = x
|
self.x = x
|
||||||
self.y = y
|
self.y = y
|
||||||
|
|
||||||
def from_real(self, foo_obj):
|
def __obj_unflatten__(cls, flattend_foo): # noqa: B902
|
||||||
x, y = foo_obj.__getstate__()
|
return cls(**dict(flattend_foo))
|
||||||
return FakeFoo(x, y)
|
|
||||||
|
|
||||||
def test_register_fake_class_valid(self):
|
def test_register_fake_class_valid(self):
|
||||||
class FakeFoo:
|
class FakeFoo:
|
||||||
|
|
@ -872,9 +1267,8 @@ class TestRegisterFakeClass(TestCase):
|
||||||
self.y = y
|
self.y = y
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_real(cls, foo_obj):
|
def __obj_unflatten__(cls, flattend_foo):
|
||||||
x, y = foo_obj.__getstate__()
|
return cls(**dict(flattend_foo))
|
||||||
return cls(x, y)
|
|
||||||
|
|
||||||
torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)
|
torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -573,9 +573,8 @@ class TestUnflatten(TestCase):
|
||||||
self.y = y
|
self.y = y
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_real(cls, foo):
|
def __obj_unflatten__(cls, flat_ctx):
|
||||||
(x, y), _ = foo.__getstate__()
|
return cls(**dict(flat_ctx))
|
||||||
return cls(x, y)
|
|
||||||
|
|
||||||
def add_tensor(self, z):
|
def add_tensor(self, z):
|
||||||
return (self.x + self.y) * z
|
return (self.x + self.y) * z
|
||||||
|
|
|
||||||
|
|
@ -17,34 +17,21 @@ from torch.testing import FileCheck
|
||||||
from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM80OrLater
|
from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM80OrLater
|
||||||
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
|
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
find_library_location,
|
|
||||||
IS_FBCODE,
|
|
||||||
IS_MACOS,
|
|
||||||
IS_SANDCASTLE,
|
|
||||||
IS_WINDOWS,
|
IS_WINDOWS,
|
||||||
run_tests,
|
run_tests,
|
||||||
TEST_CUDA,
|
TEST_CUDA,
|
||||||
TEST_WITH_ROCM,
|
TEST_WITH_ROCM,
|
||||||
TestCase,
|
TestCase,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
|
||||||
from torch.utils.hooks import RemovableHandle
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't support")
|
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't support")
|
||||||
class TestWithEffects(TestCase):
|
class TestWithEffects(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
if IS_MACOS:
|
init_torchbind_implementations()
|
||||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
|
||||||
elif IS_SANDCASTLE or IS_FBCODE:
|
|
||||||
torch.ops.load_library(
|
|
||||||
"//caffe2/test/cpp/jit:test_custom_class_registrations"
|
|
||||||
)
|
|
||||||
elif IS_WINDOWS:
|
|
||||||
lib_file_path = find_library_location("torchbind_test.dll")
|
|
||||||
torch.ops.load_library(str(lib_file_path))
|
|
||||||
else:
|
|
||||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
|
||||||
torch.ops.load_library(str(lib_file_path))
|
|
||||||
|
|
||||||
def test_print(self):
|
def test_print(self):
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -171,6 +171,10 @@ class UserStopIteration(TorchDynamoException):
|
||||||
self.value = None
|
self.value = None
|
||||||
|
|
||||||
|
|
||||||
|
class UnsafeScriptObjectError(TorchDynamoException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UncapturedHigherOrderOpError(TorchDynamoException):
|
class UncapturedHigherOrderOpError(TorchDynamoException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ except ModuleNotFoundError:
|
||||||
import torch
|
import torch
|
||||||
import torch.utils._device
|
import torch.utils._device
|
||||||
from torch._dynamo.source import (
|
from torch._dynamo.source import (
|
||||||
|
is_from_flatten_script_object_source,
|
||||||
is_from_local_source,
|
is_from_local_source,
|
||||||
is_from_optimizer_source,
|
is_from_optimizer_source,
|
||||||
TensorProperty,
|
TensorProperty,
|
||||||
|
|
@ -72,6 +73,7 @@ from .source import (
|
||||||
ChainedSource,
|
ChainedSource,
|
||||||
ConstDictKeySource,
|
ConstDictKeySource,
|
||||||
DefaultsSource,
|
DefaultsSource,
|
||||||
|
FlattenScriptObjectSource,
|
||||||
FSDPNNModuleSource,
|
FSDPNNModuleSource,
|
||||||
GetItemSource,
|
GetItemSource,
|
||||||
GlobalSource,
|
GlobalSource,
|
||||||
|
|
@ -84,6 +86,7 @@ from .source import (
|
||||||
NumpyTensorSource,
|
NumpyTensorSource,
|
||||||
ODictGetItemSource,
|
ODictGetItemSource,
|
||||||
OptimizerSource,
|
OptimizerSource,
|
||||||
|
ScriptObjectQualifiedNameSource,
|
||||||
ShapeEnvSource,
|
ShapeEnvSource,
|
||||||
TupleIteratorGetItemSource,
|
TupleIteratorGetItemSource,
|
||||||
TypeSource,
|
TypeSource,
|
||||||
|
|
@ -957,6 +960,22 @@ class GuardBuilder(GuardBuilderBase):
|
||||||
example_value=example_value,
|
example_value=example_value,
|
||||||
guard_manager_enum=guard_manager_enum,
|
guard_manager_enum=guard_manager_enum,
|
||||||
)
|
)
|
||||||
|
elif istype(source, FlattenScriptObjectSource):
|
||||||
|
assert base_guard_manager # to make mypy happy
|
||||||
|
return base_guard_manager.lambda_manager(
|
||||||
|
python_lambda=lambda x: x.__obj_flatten__(),
|
||||||
|
source=source_name,
|
||||||
|
example_value=example_value,
|
||||||
|
guard_manager_enum=guard_manager_enum,
|
||||||
|
)
|
||||||
|
elif istype(source, ScriptObjectQualifiedNameSource):
|
||||||
|
assert base_guard_manager # to make mypy happy
|
||||||
|
return base_guard_manager.lambda_manager(
|
||||||
|
python_lambda=lambda x: x._type().qualified_name(),
|
||||||
|
source=source_name,
|
||||||
|
example_value=example_value,
|
||||||
|
guard_manager_enum=guard_manager_enum,
|
||||||
|
)
|
||||||
elif istype(source, TupleIteratorGetItemSource):
|
elif istype(source, TupleIteratorGetItemSource):
|
||||||
assert base_guard_manager # to make mypy happy
|
assert base_guard_manager # to make mypy happy
|
||||||
return base_guard_manager.tuple_iterator_getitem_manager(
|
return base_guard_manager.tuple_iterator_getitem_manager(
|
||||||
|
|
@ -2602,6 +2621,14 @@ def make_dupe_guard(obj_source, dupe_source):
|
||||||
if dupe_source and dupe_source != obj_source:
|
if dupe_source and dupe_source != obj_source:
|
||||||
ser_source_is_local = is_from_local_source(dupe_source)
|
ser_source_is_local = is_from_local_source(dupe_source)
|
||||||
source_is_local = is_from_local_source(obj_source)
|
source_is_local = is_from_local_source(obj_source)
|
||||||
|
if is_from_flatten_script_object_source(
|
||||||
|
dupe_source
|
||||||
|
) or is_from_flatten_script_object_source(obj_source):
|
||||||
|
raise exc.UnsafeScriptObjectError(
|
||||||
|
f"{obj_source.name()} is alising {dupe_source.name()}. This is not supported."
|
||||||
|
f" Please do a clone for corresponding input."
|
||||||
|
)
|
||||||
|
|
||||||
# Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
|
# Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
|
||||||
# reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
|
# reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
|
||||||
# so maybe we should do this refactor before we land this...
|
# so maybe we should do this refactor before we land this...
|
||||||
|
|
|
||||||
|
|
@ -1434,7 +1434,11 @@ class OutputGraph:
|
||||||
self.remove_node(node)
|
self.remove_node(node)
|
||||||
self.real_value_cache.pop(node, None)
|
self.real_value_cache.pop(node, None)
|
||||||
|
|
||||||
used_symbols = set()
|
used_symbols: Set[sympy.Symbol] = set()
|
||||||
|
|
||||||
|
def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]):
|
||||||
|
used_symbols |= free_symbols(fake)
|
||||||
|
|
||||||
recheck_placeholders = []
|
recheck_placeholders = []
|
||||||
for node in self.placeholders:
|
for node in self.placeholders:
|
||||||
binds_symbol = placeholder_binds_symbol(node) is not None
|
binds_symbol = placeholder_binds_symbol(node) is not None
|
||||||
|
|
@ -1452,10 +1456,22 @@ class OutputGraph:
|
||||||
arg = node.meta["grapharg"]
|
arg = node.meta["grapharg"]
|
||||||
if isinstance(arg, BackwardStateGraphArg):
|
if isinstance(arg, BackwardStateGraphArg):
|
||||||
continue
|
continue
|
||||||
|
if isinstance(node.meta["grapharg"].example, torch.ScriptObject):
|
||||||
|
real_script_obj = node.meta["grapharg"].example
|
||||||
|
fake_script_obj = node.meta["grapharg"].example_strong_ref
|
||||||
|
flat_dict = dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined]
|
||||||
|
for attr in flat_dict.keys():
|
||||||
|
fake_attr_val = getattr(fake_script_obj.wrapped_obj, attr)
|
||||||
|
pytree.tree_map_only(
|
||||||
|
(torch.SymInt, torch.Tensor),
|
||||||
|
lambda t: update_used_symbols(used_symbols, t),
|
||||||
|
fake_attr_val,
|
||||||
|
)
|
||||||
|
continue
|
||||||
fake = (
|
fake = (
|
||||||
arg.fake_tensor if arg.fake_tensor is not None else arg.example
|
arg.fake_tensor if arg.fake_tensor is not None else arg.example
|
||||||
)
|
)
|
||||||
used_symbols |= free_symbols(fake)
|
update_used_symbols(used_symbols, fake)
|
||||||
|
|
||||||
# After removing unused graphargs, prune unused binds_symbol
|
# After removing unused graphargs, prune unused binds_symbol
|
||||||
for node in recheck_placeholders:
|
for node in recheck_placeholders:
|
||||||
|
|
|
||||||
|
|
@ -299,6 +299,36 @@ class ConvertIntSource(ChainedSource):
|
||||||
return f"cast_symbool_to_symint_guardless({self.base.name()})"
|
return f"cast_symbool_to_symint_guardless({self.base.name()})"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class FlattenScriptObjectSource(ChainedSource):
|
||||||
|
def __post_init__(self):
|
||||||
|
assert self.base is not None
|
||||||
|
|
||||||
|
def reconstruct(self, codegen):
|
||||||
|
self.base.reconstruct(codegen)
|
||||||
|
|
||||||
|
def guard_source(self):
|
||||||
|
return self.base.guard_source()
|
||||||
|
|
||||||
|
def name(self):
|
||||||
|
return f"{self.base.name()}.__obj_flatten__()"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class ScriptObjectQualifiedNameSource(ChainedSource):
|
||||||
|
def __post_init__(self):
|
||||||
|
assert self.base is not None
|
||||||
|
|
||||||
|
def reconstruct(self, codegen):
|
||||||
|
self.base.reconstruct(codegen)
|
||||||
|
|
||||||
|
def guard_source(self):
|
||||||
|
return self.base.guard_source()
|
||||||
|
|
||||||
|
def name(self):
|
||||||
|
return f"{self.base.name()}._type().qualified_name()"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class DefaultsSource(ChainedSource):
|
class DefaultsSource(ChainedSource):
|
||||||
idx_key: Union[int, str]
|
idx_key: Union[int, str]
|
||||||
|
|
@ -563,6 +593,14 @@ def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def is_from_flatten_script_object_source(source: Source):
|
||||||
|
if isinstance(source, FlattenScriptObjectSource):
|
||||||
|
return True
|
||||||
|
elif isinstance(source, ChainedSource):
|
||||||
|
return is_from_flatten_script_object_source(source.base)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_from_optimizer_source(source: Source):
|
def is_from_optimizer_source(source: Source):
|
||||||
if isinstance(source, OptimizerSource):
|
if isinstance(source, OptimizerSource):
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ import operator
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
from typing import List, NamedTuple, Optional, Union
|
from typing import Any, List, NamedTuple, Optional, Union
|
||||||
|
|
||||||
from torch.utils._sympy.value_ranges import ValueRanges
|
from torch.utils._sympy.value_ranges import ValueRanges
|
||||||
|
|
||||||
|
|
@ -26,6 +26,7 @@ import torch
|
||||||
|
|
||||||
from torch import SymInt
|
from torch import SymInt
|
||||||
from torch._guards import GuardSource, TracingContext
|
from torch._guards import GuardSource, TracingContext
|
||||||
|
from torch._higher_order_ops.torchbind import call_torchbind
|
||||||
from torch._ops import HigherOrderOperator
|
from torch._ops import HigherOrderOperator
|
||||||
from torch._streambase import _EventBase, _StreamBase
|
from torch._streambase import _EventBase, _StreamBase
|
||||||
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
|
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
|
||||||
|
|
@ -154,6 +155,7 @@ from .misc import (
|
||||||
)
|
)
|
||||||
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
|
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
|
||||||
from .optimizer import OptimizerVariable
|
from .optimizer import OptimizerVariable
|
||||||
|
from .script_object import TorchScriptObjectVariable
|
||||||
|
|
||||||
from .sdpa import SDPAParamsVariable
|
from .sdpa import SDPAParamsVariable
|
||||||
from .tensor import (
|
from .tensor import (
|
||||||
|
|
@ -904,13 +906,65 @@ class VariableBuilder:
|
||||||
user_cls_source=AttrSource(self.source, "__class__"),
|
user_cls_source=AttrSource(self.source, "__class__"),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
elif TorchScriptObjectVariable.is_matching_cls(type(value)):
|
||||||
|
from ..source import (
|
||||||
|
FlattenScriptObjectSource,
|
||||||
|
ScriptObjectQualifiedNameSource,
|
||||||
|
)
|
||||||
|
|
||||||
|
# This exists to allow a smoother transition.
|
||||||
|
# The implications are:
|
||||||
|
# The script objects won't be tracked as proxies.
|
||||||
|
# Methods on these objects won't show up in the graph.
|
||||||
|
# The original script object might be mutated.
|
||||||
|
if not hasattr(value, "__obj_flatten__"):
|
||||||
|
return self.wrap_user_defined(value)
|
||||||
|
|
||||||
|
# Install the guards on the fully qualified name of the script object
|
||||||
|
LazyVariableTracker.realize_all(
|
||||||
|
VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
|
||||||
|
value._type().qualified_name() # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Install the guards on the content of the script object by setting the source
|
||||||
|
# to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
|
||||||
|
LazyVariableTracker.realize_all(
|
||||||
|
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
|
||||||
|
value.__obj_flatten__()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
fake_script_obj = torch._library.fake_class_registry.to_fake_obj(
|
||||||
|
self.tx.output.fake_mode, value
|
||||||
|
)
|
||||||
|
|
||||||
|
proxy = self.tx.output.root_tracer.create_graph_input(
|
||||||
|
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
|
||||||
|
type(value),
|
||||||
|
source=self.source,
|
||||||
|
)
|
||||||
|
|
||||||
|
# setting is_unspecialized=False to not insert a as_tensor call in reconstruct by default
|
||||||
|
# seting example to be real value because these example values will be used
|
||||||
|
# as example_inputs for user compiler.
|
||||||
|
proxy.node.meta["grapharg"] = GraphArg(
|
||||||
|
self.source, value, False, None, False, fake_script_obj
|
||||||
|
)
|
||||||
|
return TorchScriptObjectVariable.create(
|
||||||
|
proxy,
|
||||||
|
fake_script_obj,
|
||||||
|
source=self.source,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
return self.wrap_user_defined(value)
|
||||||
result = UserDefinedObjectVariable(value, source=self.source)
|
|
||||||
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
|
def wrap_user_defined(self, value: Any):
|
||||||
# don't allow STORE_ATTR mutation with custom __setattr__
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
return result
|
result = UserDefinedObjectVariable(value, source=self.source)
|
||||||
return self.tx.output.side_effects.track_object_existing(value, result)
|
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
|
||||||
|
# don't allow STORE_ATTR mutation with custom __setattr__
|
||||||
|
return result
|
||||||
|
return self.tx.output.side_effects.track_object_existing(value, result)
|
||||||
|
|
||||||
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
|
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
|
||||||
if config.specialize_int and type(value) is torch.Size:
|
if config.specialize_int and type(value) is torch.Size:
|
||||||
|
|
@ -1782,6 +1836,12 @@ def wrap_fx_proxy_cls(
|
||||||
]:
|
]:
|
||||||
set_example_value(proxy.node, example_value)
|
set_example_value(proxy.node, example_value)
|
||||||
return ConstantVariable.create(example_value, **options)
|
return ConstantVariable.create(example_value, **options)
|
||||||
|
elif (
|
||||||
|
isinstance(example_value, (int, float, bool))
|
||||||
|
and proxy.node.target is call_torchbind
|
||||||
|
):
|
||||||
|
set_example_value(proxy.node, example_value)
|
||||||
|
return ConstantVariable.create(example_value, **options)
|
||||||
else:
|
else:
|
||||||
unimplemented(
|
unimplemented(
|
||||||
"torch.* op returned non-Tensor "
|
"torch.* op returned non-Tensor "
|
||||||
|
|
|
||||||
|
|
@ -549,6 +549,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
|
||||||
return StrictModeHigherOrderVariable(value, source, **kwargs)
|
return StrictModeHigherOrderVariable(value, source, **kwargs)
|
||||||
elif value.__name__ == "associative_scan":
|
elif value.__name__ == "associative_scan":
|
||||||
return AssociativeScanHigherOrderVariable(value, source, **kwargs)
|
return AssociativeScanHigherOrderVariable(value, source, **kwargs)
|
||||||
|
elif value.__name__ == "call_torchbind":
|
||||||
|
return CallTorchbindHigherOrderVariable(value, source, **kwargs)
|
||||||
else:
|
else:
|
||||||
unimplemented(f"HigherOrderOperator {value.__name__}")
|
unimplemented(f"HigherOrderOperator {value.__name__}")
|
||||||
|
|
||||||
|
|
@ -769,6 +771,34 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CallTorchbindHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||||
|
def __init__(self, hop, source, script_obj_var, method_name):
|
||||||
|
super().__init__(hop, source)
|
||||||
|
self.script_obj_var = script_obj_var
|
||||||
|
self.method_name = method_name
|
||||||
|
|
||||||
|
def call_function(
|
||||||
|
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
|
||||||
|
) -> VariableTracker:
|
||||||
|
from .builder import wrap_fx_proxy
|
||||||
|
|
||||||
|
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
|
||||||
|
|
||||||
|
args_proxy = [arg.as_proxy() for arg in args]
|
||||||
|
kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()}
|
||||||
|
return wrap_fx_proxy(
|
||||||
|
tx=tx,
|
||||||
|
proxy=tx.output.create_proxy(
|
||||||
|
"call_function",
|
||||||
|
self.value,
|
||||||
|
args=tuple(
|
||||||
|
[self.script_obj_var.as_proxy(), self.method_name] + args_proxy
|
||||||
|
),
|
||||||
|
kwargs=kwargs_proxy,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||||
@raise_hard_error_if_graph_break(
|
@raise_hard_error_if_graph_break(
|
||||||
reason="while_loop doesn't work unless it is captured completely with torch.compile."
|
reason="while_loop doesn't work unless it is captured completely with torch.compile."
|
||||||
|
|
|
||||||
80
torch/_dynamo/variables/script_object.py
Normal file
80
torch/_dynamo/variables/script_object.py
Normal file
|
|
@ -0,0 +1,80 @@
|
||||||
|
import functools
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
|
||||||
|
|
||||||
|
from .base import VariableTracker
|
||||||
|
from .user_defined import UserDefinedObjectVariable
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_hard_error_if_graph_break(reason):
|
||||||
|
def deco(fn):
|
||||||
|
@functools.wraps(fn)
|
||||||
|
def graph_break_as_hard_error(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
except Unsupported as e:
|
||||||
|
raise UnsafeScriptObjectError(e.msg) from e
|
||||||
|
|
||||||
|
return graph_break_as_hard_error
|
||||||
|
|
||||||
|
return deco
|
||||||
|
|
||||||
|
|
||||||
|
class TorchScriptObjectVariable(UserDefinedObjectVariable):
|
||||||
|
_fake_script_object_cache: Dict[int, "TorchScriptObjectVariable"] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_matching_cls(cls, user_cls: type):
|
||||||
|
return issubclass(user_cls, torch.ScriptObject)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(proxy, value, **options):
|
||||||
|
return TorchScriptObjectVariable(proxy, value, **options)
|
||||||
|
|
||||||
|
def __init__(self, proxy, value, source, **kwargs):
|
||||||
|
super().__init__(value, **kwargs)
|
||||||
|
self.proxy = proxy
|
||||||
|
self.proxy.node.meta["example_value"] = value
|
||||||
|
self.source = source
|
||||||
|
|
||||||
|
def as_proxy(self):
|
||||||
|
return self.proxy
|
||||||
|
|
||||||
|
@_raise_hard_error_if_graph_break(
|
||||||
|
"Dynamo cannot safely trace script object due to graph break."
|
||||||
|
)
|
||||||
|
def var_getattr(self, tx, name: str) -> VariableTracker:
|
||||||
|
from torch._higher_order_ops.torchbind import call_torchbind
|
||||||
|
from ..source import AttrSource
|
||||||
|
from .higher_order_ops import TorchHigherOrderOperatorVariable
|
||||||
|
|
||||||
|
method = getattr(self.value, name, None)
|
||||||
|
if method is None:
|
||||||
|
unimplemented(
|
||||||
|
f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not callable(method):
|
||||||
|
unimplemented(
|
||||||
|
"Only method calls on TorchScript objects can be supported safely."
|
||||||
|
" Please use method calls instead of attribute access."
|
||||||
|
)
|
||||||
|
|
||||||
|
return TorchHigherOrderOperatorVariable.make(
|
||||||
|
call_torchbind,
|
||||||
|
source=AttrSource(self.source, name),
|
||||||
|
script_obj_var=self,
|
||||||
|
method_name=name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We only support method calls on script objects. Interpreting the bytecodes
|
||||||
|
# should go through var_getattr then call_function instead of call_method.
|
||||||
|
#
|
||||||
|
# However, it's possible for call_method to be used directly e.g. for __setattr__.
|
||||||
|
@_raise_hard_error_if_graph_break(
|
||||||
|
"Dynamo cannot safely trace script object due to graph break."
|
||||||
|
)
|
||||||
|
def call_method(self, tx, name, args, kwargs):
|
||||||
|
unimplemented(f"call method {name} on script object is not safe.")
|
||||||
|
|
@ -205,38 +205,3 @@ class AbstractImplCtx:
|
||||||
result, min=min, max=max
|
result, min=min, max=max
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def to_fake_tensor(self, tensor: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Creates a fake tensor from a concrete tensor. Note: this is not needed for register_fake.
|
|
||||||
|
|
||||||
This is useful for register_fake_class (which is necessary for torch.compile) for custom class.
|
|
||||||
Users need to implement a from_real method that takes a real custom object and creates a fake
|
|
||||||
custom object. Users can use this API to create fake tensors for the tensor states in the custom object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor (torch.Tensor): A concrete tensor.
|
|
||||||
|
|
||||||
Example::
|
|
||||||
>>> import torch
|
|
||||||
>>> @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") # xdoctest: +SKIP
|
|
||||||
... class FakeTensorQueue:
|
|
||||||
... def __init__(self, q):
|
|
||||||
... self.queue = q
|
|
||||||
...
|
|
||||||
... @classmethod
|
|
||||||
... def from_real(cls, real_tq):
|
|
||||||
... ctx = torch.library.get_ctx()
|
|
||||||
... fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
|
|
||||||
... return cls(fake_queue)
|
|
||||||
...
|
|
||||||
... def push(self, x):
|
|
||||||
... self.queue.append(x)
|
|
||||||
...
|
|
||||||
... def pop(self):
|
|
||||||
... return self.queue.pop(0)
|
|
||||||
...
|
|
||||||
... def size(self):
|
|
||||||
... return len(self.queue)
|
|
||||||
"""
|
|
||||||
return self._fake_mode.from_tensor(tensor)
|
|
||||||
|
|
|
||||||
|
|
@ -65,8 +65,35 @@ class FakeClassRegistry:
|
||||||
global_fake_class_registry = FakeClassRegistry()
|
global_fake_class_registry = FakeClassRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: add this check at compile time for __obj_flatten__.
|
||||||
|
def _check_valid_flat_script_obj(flat_x):
|
||||||
|
if not isinstance(flat_x, tuple):
|
||||||
|
raise RuntimeError("Expect flat x to be a tuple.")
|
||||||
|
|
||||||
|
for tp in flat_x:
|
||||||
|
if not isinstance(tp, tuple):
|
||||||
|
raise RuntimeError("Expect flat x to be a tuple of tuples.")
|
||||||
|
|
||||||
|
if not len(tp) == 2 or not isinstance(tp[0], str):
|
||||||
|
raise RuntimeError(
|
||||||
|
"Expect element of flat x to be a tuple of two elements with first element being a string"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def to_fake_obj(fake_mode, x: torch.ScriptObject) -> FakeScriptObject:
|
def to_fake_obj(fake_mode, x: torch.ScriptObject) -> FakeScriptObject:
|
||||||
fake_x = _fake_obj_from_real(fake_mode, x)
|
import torch.utils._pytree as pytree
|
||||||
|
|
||||||
|
flat_x = x.__obj_flatten__() # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
_check_valid_flat_script_obj(flat_x)
|
||||||
|
|
||||||
|
fake_flattened = pytree.tree_map_only(
|
||||||
|
torch.Tensor,
|
||||||
|
lambda t: fake_mode.from_tensor(t),
|
||||||
|
flat_x,
|
||||||
|
)
|
||||||
|
|
||||||
|
fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)
|
||||||
|
|
||||||
def _call_torchbind(method_name):
|
def _call_torchbind(method_name):
|
||||||
from torch._higher_order_ops.torchbind import call_torchbind
|
from torch._higher_order_ops.torchbind import call_torchbind
|
||||||
|
|
@ -134,14 +161,13 @@ def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal]
|
||||||
|
|
||||||
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
|
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
|
||||||
class FakeTensorQueue:
|
class FakeTensorQueue:
|
||||||
def __init__(self, q):
|
def __init__(self, queue):
|
||||||
self.queue = q
|
self.queue = queue
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_real(cls, real_tq):
|
def __obj_unflatten__(cls, flattened_ctx):
|
||||||
ctx = torch.library.get_ctx()
|
ctx = {flattened_ctx[0]: flattened_ctx[1]}
|
||||||
fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.clone_queue()]
|
return cls(**ctx)
|
||||||
return cls(fake_queue)
|
|
||||||
|
|
||||||
def push(self, x):
|
def push(self, x):
|
||||||
self.queue.append(x)
|
self.queue.append(x)
|
||||||
|
|
@ -162,7 +188,9 @@ def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal]
|
||||||
|
|
||||||
from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
|
from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
|
||||||
if not from_method:
|
if not from_method:
|
||||||
raise RuntimeError(f"{fake_class} doesn't define a classmethod from_real.")
|
raise RuntimeError(
|
||||||
|
f"{fake_class} doesn't define a classmethod {_CONVERT_FROM_REAL_NAME}."
|
||||||
|
)
|
||||||
|
|
||||||
if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod):
|
if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|
@ -221,7 +249,7 @@ def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any:
|
||||||
return fake_class
|
return fake_class
|
||||||
|
|
||||||
|
|
||||||
_CONVERT_FROM_REAL_NAME = "from_real"
|
_CONVERT_FROM_REAL_NAME = "__obj_unflatten__"
|
||||||
|
|
||||||
|
|
||||||
def _fake_obj_from_real(fake_mode, x) -> Any:
|
def _fake_obj_from_real(fake_mode, x) -> Any:
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,7 @@ hop_that_doesnt_have_opinfo_test_allowlist = [
|
||||||
"with_effects",
|
"with_effects",
|
||||||
"strict_mode",
|
"strict_mode",
|
||||||
"_export_tracepoint",
|
"_export_tracepoint",
|
||||||
|
"call_torchbind",
|
||||||
]
|
]
|
||||||
|
|
||||||
torch.library.define(
|
torch.library.define(
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,35 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
_TORCHBIND_IMPLS_INITIALIZED = False
|
_TORCHBIND_IMPLS_INITIALIZED = False
|
||||||
|
|
||||||
|
_TENSOR_QUEUE_GLOBAL_TEST: Optional[torch.ScriptObject] = None
|
||||||
|
|
||||||
|
|
||||||
def init_torchbind_implementations():
|
def init_torchbind_implementations():
|
||||||
global _TORCHBIND_IMPLS_INITIALIZED
|
global _TORCHBIND_IMPLS_INITIALIZED
|
||||||
|
global _TENSOR_QUEUE_GLOBAL_TEST
|
||||||
if _TORCHBIND_IMPLS_INITIALIZED:
|
if _TORCHBIND_IMPLS_INITIALIZED:
|
||||||
return
|
return
|
||||||
|
|
||||||
load_torchbind_test_lib()
|
load_torchbind_test_lib()
|
||||||
register_fake_operators()
|
register_fake_operators()
|
||||||
register_fake_classes()
|
register_fake_classes()
|
||||||
|
_TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
|
||||||
_TORCHBIND_IMPLS_INITIALIZED = True
|
_TORCHBIND_IMPLS_INITIALIZED = True
|
||||||
|
|
||||||
|
|
||||||
|
def _empty_tensor_queue() -> torch.ScriptObject:
|
||||||
|
return torch.classes._TorchScriptTesting._TensorQueue(
|
||||||
|
torch.empty(
|
||||||
|
0,
|
||||||
|
).fill_(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# put these under a function because the corresponding library might not be loaded yet.
|
# put these under a function because the corresponding library might not be loaded yet.
|
||||||
def register_fake_operators():
|
def register_fake_operators():
|
||||||
@torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta")
|
@torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta")
|
||||||
|
|
@ -67,25 +80,23 @@ def register_fake_classes():
|
||||||
self.y = y
|
self.y = y
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_real(cls, foo):
|
def __obj_unflatten__(cls, flattend_foo):
|
||||||
(x, y), _ = foo.__getstate__()
|
return cls(**dict(flattend_foo))
|
||||||
return cls(x, y)
|
|
||||||
|
|
||||||
def add_tensor(self, z):
|
def add_tensor(self, z):
|
||||||
return (self.x + self.y) * z
|
return (self.x + self.y) * z
|
||||||
|
|
||||||
@torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor")
|
@torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor")
|
||||||
class FakeContainsTensor:
|
class FakeContainsTensor:
|
||||||
def __init__(self, x: torch.Tensor):
|
def __init__(self, t: torch.Tensor):
|
||||||
self.x = x
|
self.t = t
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_real(cls, foo):
|
def __obj_unflatten__(cls, flattend_foo):
|
||||||
ctx = torch.library.get_ctx()
|
return cls(**dict(flattend_foo))
|
||||||
return cls(ctx.to_fake_tensor(foo.get()))
|
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
return self.x
|
return self.t
|
||||||
|
|
||||||
|
|
||||||
def load_torchbind_test_lib():
|
def load_torchbind_test_lib():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user