[dynamo] support torchbind object input (#124978)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124978 Approved by: https://github.com/jansel
This commit is contained in:
parent
c165a8e71d
commit
461ffaaaf3
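
For orientation: this change teaches dynamo to accept TorchBind (torch.classes) objects as torch.compile inputs. A minimal usage sketch, assuming the _TensorQueue test class and its registered fake class from this diff:

import torch

def f(tq, x):
    # Method calls on the script object are traced into the graph as
    # torch.ops.higher_order.call_torchbind(tq, "push", ...) nodes.
    tq.push(x.sin())
    return tq.pop() + tq.size()

tq = torch.classes._TorchScriptTesting._TensorQueue(torch.empty(0).fill_(-1))
x = torch.randn(2, 3)
out = torch.compile(f, backend="eager")(tq, x)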
@@ -59,6 +59,10 @@ struct Foo : torch::CustomClassHolder {
   bool eq(c10::intrusive_ptr<Foo> other) {
     return this->x == other->x && this->y == other->y;
   }
+  std::tuple<std::tuple<std::string, int64_t>, std::tuple<std::string, int64_t>>
+  __obj_flatten__() {
+    return std::tuple(std::tuple("x", this->x), std::tuple("y", this->y));
+  }
 };
 
 struct _StaticMethod : torch::CustomClassHolder {
@@ -199,6 +203,10 @@ struct TensorQueue : torch::CustomClassHolder {
     return raw_queue;
   }
 
+  std::tuple<std::tuple<std::string, std::vector<at::Tensor>>> __obj_flatten__() {
+    return std::tuple(std::tuple("queue", this->get_raw_queue()));
+  }
+
  private:
   std::deque<at::Tensor> queue_;
   std::mutex mutex_;
@@ -370,6 +378,10 @@ struct ContainsTensor : public torch::CustomClassHolder {
     return t_;
   }
 
+  std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
+    return std::tuple(std::tuple("t", this->t_));
+  }
+
   at::Tensor t_;
 };
 
@@ -417,6 +429,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
       .def("add_tensor", &Foo::add_tensor)
       .def("__eq__", &Foo::eq)
      .def("combine", &Foo::combine)
+      .def("__obj_flatten__", &Foo::__obj_flatten__)
      .def_pickle(
          [](c10::intrusive_ptr<Foo> self) { // __getstate__
            return std::vector<int64_t>{self->x, self->y};
@@ -424,6 +437,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
           [](std::vector<int64_t> state) { // __setstate__
             return c10::make_intrusive<Foo>(state[0], state[1]);
           });
+
   m.def(
       "takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
   m.def(
@@ -551,6 +565,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
   m.class_<ContainsTensor>("_ContainsTensor")
       .def(torch::init<at::Tensor>())
      .def("get", &ContainsTensor::get)
+      .def("__obj_flatten__", &ContainsTensor::__obj_flatten__)
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<ContainsTensor>& self) -> at::Tensor {
@@ -568,6 +583,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
       .def("size", &TensorQueue::size)
      .def("clone_queue", &TensorQueue::clone_queue)
      .def("get_raw_queue", &TensorQueue::get_raw_queue)
+      .def("__obj_flatten__", &TensorQueue::__obj_flatten__)
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<TensorQueue>& self)
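
The C++ side of the contract is the __obj_flatten__ method added above: it returns a tuple of ("attribute name", value) pairs. A sketch of the matching Python side for the Foo class, following the fake-class registration used elsewhere in this diff:

# Foo::__obj_flatten__ produces (("x", x_val), ("y", y_val)); the registered
# fake class rebuilds itself from exactly that tuple.
flat = (("x", 1), ("y", 2))

class FakeFoo:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    @classmethod
    def __obj_unflatten__(cls, flattened_foo):
        return cls(**dict(flattened_foo))

fake = FakeFoo.__obj_unflatten__(flat)  # FakeFoo with x=1, y=2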
@@ -3,8 +3,11 @@
 
 import torch
 import torch.utils._pytree as pytree
+from torch._dynamo.testing import EagerAndRecordGraphs
 from torch._functorch.aot_autograd import aot_export_module
 from torch._higher_order_ops.torchbind import enable_torchbind_tracing
 
+from torch._higher_order_ops.wrap import wrap
+from torch._library.fake_class_registry import FakeScriptObject
 from torch.export import export
 from torch.export._trace import _export
@@ -16,7 +19,39 @@ from torch.testing._internal.common_utils import (
     skipIfTorchDynamo,
     TestCase,
 )
-from torch.testing._internal.torchbind_impls import init_torchbind_implementations
+from torch.testing._internal.torchbind_impls import (
+    _empty_tensor_queue,
+    init_torchbind_implementations,
+)
+
+
+def _assertEqualSkipScriptObject(test_case, exp, actual):
+    flat_exp = pytree.tree_leaves(exp)
+    flat_actual = pytree.tree_leaves(actual)
+    test_case.assertEqual(len(flat_exp), len(flat_actual))
+    for a, b in zip(flat_exp, flat_actual):
+        if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
+            continue
+        test_case.assertEqual(a, b)
+
+
+def _check_script_obj_equal(test_case, a: torch.ScriptObject, b: torch.ScriptObject):
+    return test_case.assertEqual(
+        a._type().qualified_name(), b._type().qualified_name()
+    ) and test_case.assertEqual(a.__obj_flatten__(), b.__obj_flatten__())
+
+
+def _assertEqualScriptObject(
+    test_case, exp, actual, check_obj_eq=_check_script_obj_equal
+):
+    flat_exp = pytree.tree_leaves(exp)
+    flat_actual = pytree.tree_leaves(actual)
+    test_case.assertEqual(len(flat_exp), len(flat_actual))
+    for a, b in zip(flat_exp, flat_actual):
+        if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
+            check_obj_eq(test_case, a, b)
+        else:
+            test_case.assertEqual(a, b)
+
+
 @skipIfTorchDynamo("torchbind not supported with dynamo yet")
@@ -37,9 +72,8 @@ class TestExportTorchbind(TestCase):
                 self.y = y
 
             @classmethod
-            def from_real(cls, foo):
-                (x, y), _ = foo.__getstate__()
-                return cls(x, y)
+            def __obj_unflatten__(cls, flattend_foo):
+                return cls(**dict(flattend_foo))
 
             def add_tensor(self, z):
                 test.foo_add_tensor_counter += 1
@@ -47,14 +81,12 @@ class TestExportTorchbind(TestCase):
 
         @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
         class FakeTensorQueue:
-            def __init__(self, q):
-                self.queue = q
+            def __init__(self, queue):
+                self.queue = queue
 
             @classmethod
-            def from_real(cls, real_tq):
-                ctx = torch.library.get_ctx()
-                fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
-                return cls(fake_queue)
+            def __obj_unflatten__(cls, flattened_ctx):
+                return cls(**dict(flattened_ctx))
 
             def push(self, x):
                 test.tq_push_counter += 1
@@ -89,15 +121,6 @@ class TestExportTorchbind(TestCase):
             "_TorchScriptTesting::_TensorQueue"
         )
 
-    def _assertEqualSkipScriptObject(self, exp, actual):
-        flat_exp = pytree.tree_leaves(exp)
-        flat_actual = pytree.tree_leaves(actual)
-        self.assertEqual(len(flat_exp), len(flat_actual))
-        for a, b in zip(flat_exp, flat_actual):
-            if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
-                continue
-            self.assertEqual(a, b)
-
     def _test_export_same_as_eager(
         self, f, args, kwargs=None, strict=True, pre_dispatch=False
     ):
@@ -532,7 +555,7 @@ def forward(self, arg0_1, arg1_1):
 """,
         )
         mod.check_tq_is_fake = False
-        self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
+        _assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x))
 
     @parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
     def test_make_fx_tensor_queue_methods_fakify_internal_states(
@@ -592,7 +615,7 @@ def forward(self, arg0_1, arg1_1):
         )
         # turn off tq type checking in eager execution
         mod.check_tq_is_fake = False
-        self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
+        _assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x))
         self.assertEqual(tq.size(), 0)
         self.assertEqual(tq1.size(), 0)
@@ -769,7 +792,7 @@ def forward(self, arg0_1, arg1_1):
     add = torch.ops.aten.add.Tensor(queue_pop_1, 0); queue_pop_1 = None
     return (sub, add, arg0_1)""",
         )
-        self._assertEqualSkipScriptObject(gm(tq1, x), mod(tq2, x))
+        _assertEqualSkipScriptObject(self, gm(tq1, x), mod(tq2, x))
 
     def test_aot_export_tensor_queue_operators(self):
         class Model(torch.nn.Module):
@@ -829,6 +852,377 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         )
 
 
+class TestCompileTorchbind(TestCase):
+    def setUp(self):
+        init_torchbind_implementations()
+
+        @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
+        class FakeTensorQueue:
+            def __init__(self, queue):
+                self.queue = queue
+
+            @classmethod
+            def __obj_unflatten__(cls, flattened_ctx):
+                return cls(**dict(flattened_ctx))
+
+            def push(self, x):
+                self.queue.append(x)
+
+            def pop(self):
+                return self.queue.pop(0)
+
+            def size(self):
+                return len(self.queue)
+
+        torch._dynamo.reset()
+
+    def tearDown(self):
+        torch._dynamo.reset()
+
+    def test_compile_script_object_input(self):
+        backend = EagerAndRecordGraphs()
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.check_tq_is_fake = True
+
+            def forward(self, tq, x):
+                tq.push(x.cos())
+                tq.push(x.sin())
+                x_sin = tq.pop() - tq.size()
+                return x_sin, tq
+
+        mod = Model()
+        tq1 = torch.classes._TorchScriptTesting._TensorQueue(
+            torch.empty(
+                0,
+            ).fill_(-1)
+        )
+        tq2 = torch.classes._TorchScriptTesting._TensorQueue(
+            torch.empty(
+                0,
+            ).fill_(-1)
+        )
+        tq3 = torch.classes._TorchScriptTesting._TensorQueue(
+            torch.empty(
+                0,
+            ).fill_(-1)
+        )
+        tq4 = torch.classes._TorchScriptTesting._TensorQueue(
+            torch.empty(
+                0,
+            ).fill_(-1)
+        )
+        x = torch.randn(2, 3)
+        ret = torch.compile(mod, backend=backend)(tq1, x)
+        eager_ret = mod(tq2, x)
+        _assertEqualSkipScriptObject(self, ret, eager_ret)
+        self.assertEqual(ret[1].size(), eager_ret[1].size())
+        self.assertEqual(ret[1].pop(), eager_ret[1].pop())
+        # Note that the dynamo captured graph
+        # does not return L_tq_ as output. This is because it's able
+        # to detect that L_tq_ is an input and therefore doesn't return
+        # it as a graph output. Related logic is in dynamo/codegen.py.
+        self.assertExpectedInline(
+            backend.graphs[0].code.strip(),
+            """\
+def forward(self, L_tq_ : torch.ScriptObject, L_x_ : torch.Tensor):
+    l_tq_ = L_tq_
+    l_x_ = L_x_
+    cos = l_x_.cos()
+    call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', cos); cos = None
+    sin = l_x_.sin(); l_x_ = None
+    call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', sin); sin = None
+    call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'pop')
+    call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None
+    x_sin = call_torchbind_2 - 1; call_torchbind_2 = None
+    return (x_sin,)""",
+        )
+
+    def test_compile_script_object_input_guards(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.check_tq_is_fake = True
+
+            def forward(self, tq, x):
+                tq.push(x.cos())
+                tq.push(x.sin())
+                x_sin = tq.pop() - tq.size()
+                return x_sin, tq
+
+        mod = Model()
+        cnt = torch._dynamo.testing.CompileCounter()
+        x = torch.randn(2, 3)
+
+        tq1 = _empty_tensor_queue()
+        torch.compile(mod, backend=cnt)(tq1, x)
+        self.assertEqual(cnt.frame_count, 1)
+
+        tq2 = _empty_tensor_queue()
+        for _ in range(10):
+            tq2.push(torch.randn(4, 5, requires_grad=False))
+        torch.compile(mod, backend=cnt)(tq2, x)
+        # Queue length change causes re-compile
+        self.assertEqual(cnt.frame_count, 2)
+
+        tq3 = _empty_tensor_queue()
+        tq3.push(torch.randn(2, 3, requires_grad=False))
+        torch.compile(mod, backend=cnt)(tq3, x)
+        # Tensor in queue changing shape causes re-compile
+        self.assertEqual(cnt.frame_count, 3)
+
+        tq4 = _empty_tensor_queue()
+        tq4.push(torch.randn(2, 3, requires_grad=False))
+        torch.compile(mod, backend=cnt)(tq4, x)
+        # No recompile
+        self.assertEqual(cnt.frame_count, 3)
+
+        tq5 = _empty_tensor_queue()
+        tq5.push(torch.randn(2, 3, requires_grad=True))
+        torch.compile(mod, backend=cnt)(tq5, x)
+        # Tensor in queue changing dispatch key causes re-compile
+        self.assertEqual(cnt.frame_count, 4)
+
+        tq6 = _empty_tensor_queue()
+        tq6.push(torch.randn(2, 3, requires_grad=True, dtype=torch.float64))
+        torch.compile(mod, backend=cnt)(tq6, x)
+        # Tensor in queue changing dtype causes re-compile
+        self.assertEqual(cnt.frame_count, 5)
+
+    def test_compile_script_object_input_automatic_dynamic_shape(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.check_tq_is_fake = True
+
+            def forward(self, tq, x):
+                tq.push(x.cos())
+                tq.push(x.sin())
+                x_sin = tq.pop() - tq.size()
+                return x_sin, tq
+
+        mod = Model()
+        cnt = torch._dynamo.testing.CompileCounter()
+        x = torch.randn(2, 3)
+
+        tq1 = _empty_tensor_queue()
+        tq1.push(torch.randn(2, 3, requires_grad=False))
+        torch.compile(mod, backend=cnt)(tq1, x)
+        self.assertEqual(cnt.frame_count, 1)
+
+        tq2 = _empty_tensor_queue()
+        # make the first tensor's second dim dynamic
+        tq2.push(torch.randn(2, 4, requires_grad=False))
+        torch.compile(mod, backend=cnt)(tq2, x)
+        self.assertEqual(cnt.frame_count, 2)
+
+        tq3 = _empty_tensor_queue()
+        tq3.push(torch.randn(2, 5, requires_grad=False))
+        # should cause no re-compilation
+        torch.compile(mod, backend=cnt)(tq3, x)
+        self.assertEqual(cnt.frame_count, 2)
+
+    def test_compile_error_on_input_aliasing_contents(self):
+        backend = EagerAndRecordGraphs()
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.check_tq_is_fake = True
+
+            def forward(self, tq, x):
+                return x.sin(), tq.pop().cos()
+
+        x = torch.randn(2, 3)
+        mod = Model()
+
+        tq1 = _empty_tensor_queue()
+        tq1.push(x)
+        with self.assertRaisesRegex(RuntimeError, "is aliasing"):
+            torch.compile(mod, backend=backend)(tq1, x)
+
+    def test_compile_error_on_script_obj_setattr(self):
+        def setattr_f(tq):
+            tq.a = 1
+            return tq
+
+        with self.assertRaisesRegex(
+            RuntimeError, "call method __setattr__ on script object is not safe"
+        ):
+            torch.compile(setattr_f, backend="eager")(_empty_tensor_queue())
+
+    def test_compile_error_on_script_obj_missing_attr(self):
+        def setattr_f(tq):
+            return tq._not_defined_attr
+
+        with self.assertRaisesRegex(
+            RuntimeError, "doesn't define method _not_defined_attr"
+        ):
+            torch.compile(setattr_f, backend="eager")(_empty_tensor_queue())
+
+    def test_compile_body_aliasing_contents(self):
+        backend = EagerAndRecordGraphs()
+
+        def f(tq, x):
+            x1 = x.view(-1)
+            x2 = x.permute(1, 0)
+            tq.push(x1)
+            tq.push(x2)
+            return x1 - tq.size(), x2 + tq.size(), tq
+
+        x = torch.randn(2, 3)
+        _assertEqualScriptObject(
+            self,
+            f(_empty_tensor_queue(), x),
+            torch.compile(f, backend=backend)(_empty_tensor_queue(), x),
+        )
+        if not torch._dynamo.is_compiling():
+            self.assertExpectedInline(
+                backend.graphs[0].code.strip(),
+                """\
+def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
+    l_x_ = L_x_
+    l_tq_ = L_tq_
+    x1 = l_x_.view(-1)
+    x2 = l_x_.permute(1, 0); l_x_ = None
+    call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x1)
+    call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x2)
+    call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'size')
+    sub = x1 - 2; x1 = None
+    call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None
+    add = x2 + 2; x2 = None
+    return (sub, add)""",
+            )
+
+    def test_compile_error_on_non_fakified_method(self):
+        backend = EagerAndRecordGraphs()
+
+        def f(tq, x):
+            x1 = x.view(-1)
+            x2 = x.permute(1, 0)
+            tq.push(x1)
+            tq.push(x2)
+            # Though the real tensor queue implements a method clone_queue,
+            # the fakified version doesn't.
+            flat_obj = tq.clone_queue()
+            return flat_obj
+
+        x = torch.randn(2, 3)
+        with self.assertRaisesRegex(
+            RuntimeError, "FakeScriptObject doesn't define method"
+        ):
+            torch.compile(f, backend=backend)(_empty_tensor_queue(), x)
+
+    def test_compile_obj_as_hop_input(self):
+        def f(tq, x):
+            def fn(tq, x):
+                tq.push(x)
+                return x.sin()
+
+            return wrap(fn, tq, x)
+
+        x = torch.randn(2, 3)
+        _assertEqualScriptObject(
+            self,
+            f(_empty_tensor_queue(), x),
+            torch.compile(f, backend="eager")(_empty_tensor_queue(), x),
+        )
+
+    def test_compile_obj_closure(self):
+        def f(x):
+            def inner_f(x):
+                tq.push(x.sin())
+
+            inner_f(x)
+            return tq.pop(), tq
+
+        opt_f = torch.compile(f, backend="eager")
+
+        tq = _empty_tensor_queue()
+        x = torch.randn(3, 2)
+        _assertEqualScriptObject(self, f(x), opt_f(x))
+
+    def test_compile_global_obj(self):
+        global _TENSOR_QUEUE_GLOBAL_TEST
+        _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
+
+        def f(x):
+            _TENSOR_QUEUE_GLOBAL_TEST.push(x.sin())
+            return _TENSOR_QUEUE_GLOBAL_TEST.pop(), _TENSOR_QUEUE_GLOBAL_TEST
+
+        opt_f = torch.compile(f, backend="eager")
+        x = torch.randn(3, 2)
+        eager_ret = f(x)
+        opt_ret = opt_f(x)
+        _assertEqualScriptObject(self, eager_ret, opt_ret)
+
+    def test_compile_obj_graph_breaks(self):
+        cnt = torch._dynamo.testing.CompileCounter()
+
+        def f(tq, x):
+            tq.push(x.sin())
+            tq.push(x.sin())
+            torch._dynamo.graph_break()
+            tq.pop()
+            torch._dynamo.graph_break()
+            tq.push(x.cos() + tq.size())
+            torch._dynamo.graph_break()
+            tq.push(x.cos() - tq.size())
+            return x, tq.pop(), tq
+
+        opt_f = torch.compile(f, backend=cnt)
+        x = torch.randn(3, 2)
+        _assertEqualScriptObject(
+            self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x)
+        )
+        self.assertEqual(cnt.frame_count, 4)
+
+    def test_compile_obj_attributes(self):
+        backend = EagerAndRecordGraphs()
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.tq = _empty_tensor_queue()
+
+            def forward(self, x):
+                self.tq.push(x)
+                return self.tq.pop()
+
+        x = torch.randn(2, 3)
+        opt_f = torch.compile(Model(), backend=backend)
+        _assertEqualScriptObject(self, Model()(x), opt_f(x))
+        self.assertEqual(len(backend.graphs), 1)
+        # The script object is lifted as a graph input. In the future, we would want to consolidate this
+        # with non-strict behavior, where they're set as attributes.
+        self.assertExpectedInline(
+            backend.graphs[0].code.strip(),
+            """\
+def forward(self, L_self_tq : torch.ScriptObject, L_x_ : torch.Tensor):
+    l_self_tq = L_self_tq
+    l_x_ = L_x_
+    call_torchbind = torch.ops.higher_order.call_torchbind(l_self_tq, 'push', l_x_); l_x_ = None
+    call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_self_tq, 'pop'); l_self_tq = None
+    return (call_torchbind_1,)""",
+        )
+
+    def test_compile_obj_torchbind_op(self):
+        def f(tq, x):
+            torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
+            torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1)
+            torch.ops._TorchScriptTesting.queue_pop(tq)
+            torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
+            return tq.pop(), tq.pop() + tq.size(), tq
+
+        opt_f = torch.compile(f, backend="eager")
+        x = torch.randn(2)
+        _assertEqualScriptObject(
+            self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x)
+        )
+
+
 @skipIfTorchDynamo("torchbind not supported with dynamo yet")
 class TestRegisterFakeClass(TestCase):
     def setUp(self):
@@ -845,7 +1239,9 @@ class TestRegisterFakeClass(TestCase):
         pass
 
     def test_register_fake_class_no_from_real(self):
-        with self.assertRaisesRegex(RuntimeError, "define a classmethod from_real"):
+        with self.assertRaisesRegex(
+            RuntimeError, "define a classmethod __obj_unflatten__"
+        ):
 
             @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
             class InvalidFakeFoo:
@@ -861,9 +1257,8 @@ class TestRegisterFakeClass(TestCase):
                     self.x = x
                     self.y = y
 
-                def from_real(self, foo_obj):
-                    x, y = foo_obj.__getstate__()
-                    return FakeFoo(x, y)
+                def __obj_unflatten__(cls, flattend_foo):  # noqa: B902
+                    return cls(**dict(flattend_foo))
 
     def test_register_fake_class_valid(self):
         class FakeFoo:
@@ -872,9 +1267,8 @@ class TestRegisterFakeClass(TestCase):
                 self.y = y
 
            @classmethod
-            def from_real(cls, foo_obj):
-                x, y = foo_obj.__getstate__()
-                return cls(x, y)
+            def __obj_unflatten__(cls, flattend_foo):
+                return cls(**dict(flattend_foo))
 
        torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)
 
@@ -573,9 +573,8 @@ class TestUnflatten(TestCase):
                 self.y = y
 
            @classmethod
-            def from_real(cls, foo):
-                (x, y), _ = foo.__getstate__()
-                return cls(x, y)
+            def __obj_unflatten__(cls, flat_ctx):
+                return cls(**dict(flat_ctx))
 
            def add_tensor(self, z):
                return (self.x + self.y) * z
@@ -17,34 +17,21 @@ from torch.testing import FileCheck
 from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM80OrLater
 from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
 from torch.testing._internal.common_utils import (
-    find_library_location,
-    IS_FBCODE,
-    IS_MACOS,
-    IS_SANDCASTLE,
-    IS_WINDOWS,
     run_tests,
     TEST_CUDA,
     TEST_WITH_ROCM,
     TestCase,
 )
+
+from torch.testing._internal.torchbind_impls import init_torchbind_implementations
 from torch.utils.hooks import RemovableHandle
 
 
 @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
 class TestWithEffects(TestCase):
     def setUp(self):
-        if IS_MACOS:
-            raise unittest.SkipTest("non-portable load_library call used in test")
-        elif IS_SANDCASTLE or IS_FBCODE:
-            torch.ops.load_library(
-                "//caffe2/test/cpp/jit:test_custom_class_registrations"
-            )
-        elif IS_WINDOWS:
-            lib_file_path = find_library_location("torchbind_test.dll")
-            torch.ops.load_library(str(lib_file_path))
-        else:
-            lib_file_path = find_library_location("libtorchbind_test.so")
-            torch.ops.load_library(str(lib_file_path))
+        init_torchbind_implementations()
 
     def test_print(self):
         class M(torch.nn.Module):
@@ -171,6 +171,10 @@ class UserStopIteration(TorchDynamoException):
         self.value = None
 
 
+class UnsafeScriptObjectError(TorchDynamoException):
+    pass
+
+
 class UncapturedHigherOrderOpError(TorchDynamoException):
     pass
@@ -41,6 +41,7 @@ except ModuleNotFoundError:
 import torch
 import torch.utils._device
 from torch._dynamo.source import (
+    is_from_flatten_script_object_source,
    is_from_local_source,
    is_from_optimizer_source,
    TensorProperty,
@@ -72,6 +73,7 @@ from .source import (
     ChainedSource,
    ConstDictKeySource,
    DefaultsSource,
+    FlattenScriptObjectSource,
    FSDPNNModuleSource,
    GetItemSource,
    GlobalSource,
@@ -84,6 +86,7 @@ from .source import (
     NumpyTensorSource,
    ODictGetItemSource,
    OptimizerSource,
+    ScriptObjectQualifiedNameSource,
    ShapeEnvSource,
    TupleIteratorGetItemSource,
    TypeSource,
@@ -957,6 +960,22 @@ class GuardBuilder(GuardBuilderBase):
                 example_value=example_value,
                 guard_manager_enum=guard_manager_enum,
             )
+        elif istype(source, FlattenScriptObjectSource):
+            assert base_guard_manager  # to make mypy happy
+            return base_guard_manager.lambda_manager(
+                python_lambda=lambda x: x.__obj_flatten__(),
+                source=source_name,
+                example_value=example_value,
+                guard_manager_enum=guard_manager_enum,
+            )
+        elif istype(source, ScriptObjectQualifiedNameSource):
+            assert base_guard_manager  # to make mypy happy
+            return base_guard_manager.lambda_manager(
+                python_lambda=lambda x: x._type().qualified_name(),
+                source=source_name,
+                example_value=example_value,
+                guard_manager_enum=guard_manager_enum,
+            )
         elif istype(source, TupleIteratorGetItemSource):
             assert base_guard_manager  # to make mypy happy
             return base_guard_manager.tuple_iterator_getitem_manager(
@@ -2602,6 +2621,14 @@ def make_dupe_guard(obj_source, dupe_source):
     if dupe_source and dupe_source != obj_source:
         ser_source_is_local = is_from_local_source(dupe_source)
         source_is_local = is_from_local_source(obj_source)
+        if is_from_flatten_script_object_source(
+            dupe_source
+        ) or is_from_flatten_script_object_source(obj_source):
+            raise exc.UnsafeScriptObjectError(
+                f"{obj_source.name()} is aliasing {dupe_source.name()}. This is not supported."
+                f" Please make a clone of the corresponding input."
+            )
+
         # Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
         # reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
         # so maybe we should do this refactor before we land this...
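
Conceptually, the two new lambda-managed guards re-derive at guard time the same values that were guarded on at trace time. A rough sketch (not the actual guard-manager internals) of what they check for a script object input tq:

def qualified_name_guard(tq, expected_name):
    # installed via ScriptObjectQualifiedNameSource
    return tq._type().qualified_name() == expected_name

def flatten_guard(tq, contents_check):
    # installed via FlattenScriptObjectSource; contents_check stands for the
    # tensor guards (dtype/shape/requires_grad) dynamo placed on the
    # flattened value.
    return contents_check(tq.__obj_flatten__())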
@@ -1434,7 +1434,11 @@ class OutputGraph:
                 self.remove_node(node)
                 self.real_value_cache.pop(node, None)
 
-        used_symbols = set()
+        used_symbols: Set[sympy.Symbol] = set()
+
+        def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]):
+            used_symbols |= free_symbols(fake)
+
         recheck_placeholders = []
         for node in self.placeholders:
             binds_symbol = placeholder_binds_symbol(node) is not None
@@ -1452,10 +1456,22 @@ class OutputGraph:
             arg = node.meta["grapharg"]
             if isinstance(arg, BackwardStateGraphArg):
                 continue
+            if isinstance(node.meta["grapharg"].example, torch.ScriptObject):
+                real_script_obj = node.meta["grapharg"].example
+                fake_script_obj = node.meta["grapharg"].example_strong_ref
+                flat_dict = dict(real_script_obj.__obj_flatten__())  # type: ignore[attr-defined]
+                for attr in flat_dict.keys():
+                    fake_attr_val = getattr(fake_script_obj.wrapped_obj, attr)
+                    pytree.tree_map_only(
+                        (torch.SymInt, torch.Tensor),
+                        lambda t: update_used_symbols(used_symbols, t),
+                        fake_attr_val,
+                    )
+                continue
             fake = (
                 arg.fake_tensor if arg.fake_tensor is not None else arg.example
             )
-            used_symbols |= free_symbols(fake)
+            update_used_symbols(used_symbols, fake)
 
         # After removing unused graphargs, prune unused binds_symbol
         for node in recheck_placeholders:
@@ -299,6 +299,36 @@ class ConvertIntSource(ChainedSource):
         return f"cast_symbool_to_symint_guardless({self.base.name()})"
 
 
+@dataclasses.dataclass(frozen=True)
+class FlattenScriptObjectSource(ChainedSource):
+    def __post_init__(self):
+        assert self.base is not None
+
+    def reconstruct(self, codegen):
+        self.base.reconstruct(codegen)
+
+    def guard_source(self):
+        return self.base.guard_source()
+
+    def name(self):
+        return f"{self.base.name()}.__obj_flatten__()"
+
+
+@dataclasses.dataclass(frozen=True)
+class ScriptObjectQualifiedNameSource(ChainedSource):
+    def __post_init__(self):
+        assert self.base is not None
+
+    def reconstruct(self, codegen):
+        self.base.reconstruct(codegen)
+
+    def guard_source(self):
+        return self.base.guard_source()
+
+    def name(self):
+        return f"{self.base.name()}._type().qualified_name()"
+
+
 @dataclasses.dataclass(frozen=True)
 class DefaultsSource(ChainedSource):
     idx_key: Union[int, str]
@@ -563,6 +593,14 @@ def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
     return True
 
 
+def is_from_flatten_script_object_source(source: Source):
+    if isinstance(source, FlattenScriptObjectSource):
+        return True
+    elif isinstance(source, ChainedSource):
+        return is_from_flatten_script_object_source(source.base)
+    return False
+
+
 def is_from_optimizer_source(source: Source):
     if isinstance(source, OptimizerSource):
         return True
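
For illustration, the guard-source names these chained sources produce for a local input named tq (a sketch; LocalSource stands for the base source dynamo builds for a local variable):

base = LocalSource("tq")                        # name: "L['tq']"
FlattenScriptObjectSource(base).name()          # "L['tq'].__obj_flatten__()"
ScriptObjectQualifiedNameSource(base).name()    # "L['tq']._type().qualified_name()"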
@@ -13,7 +13,7 @@ import operator
 import re
 import sys
 import types
-from typing import List, NamedTuple, Optional, Union
+from typing import Any, List, NamedTuple, Optional, Union
 
 from torch.utils._sympy.value_ranges import ValueRanges
 
@@ -26,6 +26,7 @@ import torch
 
 from torch import SymInt
 from torch._guards import GuardSource, TracingContext
+from torch._higher_order_ops.torchbind import call_torchbind
 from torch._ops import HigherOrderOperator
 from torch._streambase import _EventBase, _StreamBase
 from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
@@ -154,6 +155,7 @@ from .misc import (
 )
 from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
 from .optimizer import OptimizerVariable
+from .script_object import TorchScriptObjectVariable
 
 from .sdpa import SDPAParamsVariable
 from .tensor import (
@@ -904,13 +906,65 @@ class VariableBuilder:
                     user_cls_source=AttrSource(self.source, "__class__"),
                 ),
             )
+        elif TorchScriptObjectVariable.is_matching_cls(type(value)):
+            from ..source import (
+                FlattenScriptObjectSource,
+                ScriptObjectQualifiedNameSource,
+            )
+
+            # This exists to allow a smoother transition.
+            # The implications are:
+            # The script objects won't be tracked as proxies.
+            # Methods on these objects won't show up in the graph.
+            # The original script object might be mutated.
+            if not hasattr(value, "__obj_flatten__"):
+                return self.wrap_user_defined(value)
+
+            # Install the guards on the fully qualified name of the script object
+            LazyVariableTracker.realize_all(
+                VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
+                    value._type().qualified_name()  # type: ignore[attr-defined]
+                )
+            )
+            # Install the guards on the content of the script object by setting the source
+            # to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
+            LazyVariableTracker.realize_all(
+                VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
+                    value.__obj_flatten__()
+                )
+            )
+
+            fake_script_obj = torch._library.fake_class_registry.to_fake_obj(
+                self.tx.output.fake_mode, value
+            )
+
+            proxy = self.tx.output.root_tracer.create_graph_input(
+                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
+                type(value),
+                source=self.source,
+            )
+
+            # setting is_unspecialized=False to not insert an as_tensor call in reconstruct by default
+            # setting example to be real value because these example values will be used
+            # as example_inputs for user compiler.
+            proxy.node.meta["grapharg"] = GraphArg(
+                self.source, value, False, None, False, fake_script_obj
+            )
+            return TorchScriptObjectVariable.create(
+                proxy,
+                fake_script_obj,
+                source=self.source,
+            )
         else:
-            self.install_guards(GuardBuilder.TYPE_MATCH)
-            result = UserDefinedObjectVariable(value, source=self.source)
-            if not SideEffects.cls_supports_mutation_side_effects(type(value)):
-                # don't allow STORE_ATTR mutation with custom __setattr__
-                return result
-            return self.tx.output.side_effects.track_object_existing(value, result)
+            return self.wrap_user_defined(value)
+
+    def wrap_user_defined(self, value: Any):
+        self.install_guards(GuardBuilder.TYPE_MATCH)
+        result = UserDefinedObjectVariable(value, source=self.source)
+        if not SideEffects.cls_supports_mutation_side_effects(type(value)):
+            # don't allow STORE_ATTR mutation with custom __setattr__
+            return result
+        return self.tx.output.side_effects.track_object_existing(value, result)
 
     def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
         if config.specialize_int and type(value) is torch.Size:
@@ -1782,6 +1836,12 @@ def wrap_fx_proxy_cls(
     ]:
         set_example_value(proxy.node, example_value)
         return ConstantVariable.create(example_value, **options)
+    elif (
+        isinstance(example_value, (int, float, bool))
+        and proxy.node.target is call_torchbind
+    ):
+        set_example_value(proxy.node, example_value)
+        return ConstantVariable.create(example_value, **options)
     else:
         unimplemented(
             "torch.* op returned non-Tensor "
@@ -549,6 +549,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
             return StrictModeHigherOrderVariable(value, source, **kwargs)
         elif value.__name__ == "associative_scan":
             return AssociativeScanHigherOrderVariable(value, source, **kwargs)
+        elif value.__name__ == "call_torchbind":
+            return CallTorchbindHigherOrderVariable(value, source, **kwargs)
         else:
             unimplemented(f"HigherOrderOperator {value.__name__}")
 
@@ -769,6 +771,34 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
         )
 
 
+class CallTorchbindHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    def __init__(self, hop, source, script_obj_var, method_name):
+        super().__init__(hop, source)
+        self.script_obj_var = script_obj_var
+        self.method_name = method_name
+
+    def call_function(
+        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
+    ) -> VariableTracker:
+        from .builder import wrap_fx_proxy
+
+        args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
+
+        args_proxy = [arg.as_proxy() for arg in args]
+        kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()}
+        return wrap_fx_proxy(
+            tx=tx,
+            proxy=tx.output.create_proxy(
+                "call_function",
+                self.value,
+                args=tuple(
+                    [self.script_obj_var.as_proxy(), self.method_name] + args_proxy
+                ),
+                kwargs=kwargs_proxy,
+            ),
+        )
+
+
 class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
     @raise_hard_error_if_graph_break(
         reason="while_loop doesn't work unless it is captured completely with torch.compile."
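
The calling convention this variable emits matches the captured graphs in the tests above: the script object proxy and the method name lead the argument list. Roughly:

# tq.push(x) inside a compiled region becomes a graph node like
out = torch.ops.higher_order.call_torchbind(tq, "push", x)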
torch/_dynamo/variables/script_object.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+import functools
+from typing import Dict
+
+import torch
+from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
+
+from .base import VariableTracker
+from .user_defined import UserDefinedObjectVariable
+
+
+def _raise_hard_error_if_graph_break(reason):
+    def deco(fn):
+        @functools.wraps(fn)
+        def graph_break_as_hard_error(*args, **kwargs):
+            try:
+                return fn(*args, **kwargs)
+            except Unsupported as e:
+                raise UnsafeScriptObjectError(e.msg) from e
+
+        return graph_break_as_hard_error
+
+    return deco
+
+
+class TorchScriptObjectVariable(UserDefinedObjectVariable):
+    _fake_script_object_cache: Dict[int, "TorchScriptObjectVariable"] = {}
+
+    @classmethod
+    def is_matching_cls(cls, user_cls: type):
+        return issubclass(user_cls, torch.ScriptObject)
+
+    @staticmethod
+    def create(proxy, value, **options):
+        return TorchScriptObjectVariable(proxy, value, **options)
+
+    def __init__(self, proxy, value, source, **kwargs):
+        super().__init__(value, **kwargs)
+        self.proxy = proxy
+        self.proxy.node.meta["example_value"] = value
+        self.source = source
+
+    def as_proxy(self):
+        return self.proxy
+
+    @_raise_hard_error_if_graph_break(
+        "Dynamo cannot safely trace script object due to graph break."
+    )
+    def var_getattr(self, tx, name: str) -> VariableTracker:
+        from torch._higher_order_ops.torchbind import call_torchbind
+        from ..source import AttrSource
+        from .higher_order_ops import TorchHigherOrderOperatorVariable
+
+        method = getattr(self.value, name, None)
+        if method is None:
+            unimplemented(
+                f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?"
+            )
+
+        if not callable(method):
+            unimplemented(
+                "Only method calls on TorchScript objects can be supported safely."
+                " Please use method calls instead of attribute access."
+            )
+
+        return TorchHigherOrderOperatorVariable.make(
+            call_torchbind,
+            source=AttrSource(self.source, name),
+            script_obj_var=self,
+            method_name=name,
+        )
+
+    # We only support method calls on script objects. Interpreting the bytecodes
+    # should go through var_getattr then call_function instead of call_method.
+    #
+    # However, it's possible for call_method to be used directly e.g. for __setattr__.
+    @_raise_hard_error_if_graph_break(
+        "Dynamo cannot safely trace script object due to graph break."
+    )
+    def call_method(self, tx, name, args, kwargs):
+        unimplemented(f"call method {name} on script object is not safe.")
@@ -205,38 +205,3 @@ class AbstractImplCtx:
                 result, min=min, max=max
             )
         return result
-
-    def to_fake_tensor(self, tensor: torch.Tensor):
-        """
-        Creates a fake tensor from a concrete tensor. Note: this is not needed for register_fake.
-
-        This is useful for register_fake_class (which is necessary for torch.compile) for custom class.
-        Users need to implement a from_real method that takes a real custom object and creates a fake
-        custom object. Users can use this API to create fake tensors for the tensor states in the custom object.
-
-        Args:
-            tensor (torch.Tensor): A concrete tensor.
-
-        Example::
-            >>> import torch
-            >>> @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")  # xdoctest: +SKIP
-            ... class FakeTensorQueue:
-            ...     def __init__(self, q):
-            ...         self.queue = q
-            ...
-            ...     @classmethod
-            ...     def from_real(cls, real_tq):
-            ...         ctx = torch.library.get_ctx()
-            ...         fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
-            ...         return cls(fake_queue)
-            ...
-            ...     def push(self, x):
-            ...         self.queue.append(x)
-            ...
-            ...     def pop(self):
-            ...         return self.queue.pop(0)
-            ...
-            ...     def size(self):
-            ...         return len(self.queue)
-        """
-        return self._fake_mode.from_tensor(tensor)
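
The removal above is the API migration at the heart of this diff: fake classes no longer fakify tensor state by hand via a from_real classmethod and ctx.to_fake_tensor. A before/after sketch, based on the examples in this diff:

# before: manual fakification of each tensor
@classmethod
def from_real(cls, real_tq):
    ctx = torch.library.get_ctx()
    return cls([ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()])

# after: rebuild from the already-fakified flattened contents
@classmethod
def __obj_unflatten__(cls, flattened_ctx):
    return cls(**dict(flattened_ctx))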
@@ -65,8 +65,35 @@ class FakeClassRegistry:
 global_fake_class_registry = FakeClassRegistry()
 
 
+# TODO: add this check at compile time for __obj_flatten__.
+def _check_valid_flat_script_obj(flat_x):
+    if not isinstance(flat_x, tuple):
+        raise RuntimeError("Expect flat x to be a tuple.")
+
+    for tp in flat_x:
+        if not isinstance(tp, tuple):
+            raise RuntimeError("Expect flat x to be a tuple of tuples.")
+
+        if not len(tp) == 2 or not isinstance(tp[0], str):
+            raise RuntimeError(
+                "Expect element of flat x to be a tuple of two elements with first element being a string"
+            )
+
+
 def to_fake_obj(fake_mode, x: torch.ScriptObject) -> FakeScriptObject:
-    fake_x = _fake_obj_from_real(fake_mode, x)
+    import torch.utils._pytree as pytree
+
+    flat_x = x.__obj_flatten__()  # type: ignore[attr-defined]
+
+    _check_valid_flat_script_obj(flat_x)
+
+    fake_flattened = pytree.tree_map_only(
+        torch.Tensor,
+        lambda t: fake_mode.from_tensor(t),
+        flat_x,
+    )
+
+    fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)
 
     def _call_torchbind(method_name):
         from torch._higher_order_ops.torchbind import call_torchbind
@@ -134,14 +161,13 @@ def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal]
 
         @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
         class FakeTensorQueue:
-            def __init__(self, q):
-                self.queue = q
+            def __init__(self, queue):
+                self.queue = queue
 
             @classmethod
-            def from_real(cls, real_tq):
-                ctx = torch.library.get_ctx()
-                fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.clone_queue()]
-                return cls(fake_queue)
+            def __obj_unflatten__(cls, flattened_ctx):
+                ctx = {flattened_ctx[0]: flattened_ctx[1]}
+                return cls(**ctx)
 
             def push(self, x):
                 self.queue.append(x)
@@ -162,7 +188,9 @@ def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal]
 
     from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
     if not from_method:
-        raise RuntimeError(f"{fake_class} doesn't define a classmethod from_real.")
+        raise RuntimeError(
+            f"{fake_class} doesn't define a classmethod {_CONVERT_FROM_REAL_NAME}."
+        )
 
     if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod):
         raise RuntimeError(
@@ -221,7 +249,7 @@ def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any:
     return fake_class
 
 
-_CONVERT_FROM_REAL_NAME = "from_real"
+_CONVERT_FROM_REAL_NAME = "__obj_unflatten__"
 
 
 def _fake_obj_from_real(fake_mode, x) -> Any:
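
A worked example of the new to_fake_obj pipeline with the _TensorQueue class (a sketch; names follow the code above):

# real_tq.__obj_flatten__()        -> (("queue", [t0, t1]),)
# after pytree.tree_map_only(...)  -> (("queue", [fake_t0, fake_t1]),)
# FakeTensorQueue.__obj_unflatten__(fake_flattened)
#                                  -> FakeTensorQueue(queue=[fake_t0, fake_t1])
fake_tq = to_fake_obj(fake_mode, real_tq)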
@@ -58,6 +58,7 @@ hop_that_doesnt_have_opinfo_test_allowlist = [
     "with_effects",
     "strict_mode",
     "_export_tracepoint",
+    "call_torchbind",
 ]
 
 torch.library.define(
@@ -1,22 +1,35 @@
 import contextlib
+from typing import Optional
+
 import torch
 
 
 _TORCHBIND_IMPLS_INITIALIZED = False
 
+_TENSOR_QUEUE_GLOBAL_TEST: Optional[torch.ScriptObject] = None
+
 
 def init_torchbind_implementations():
     global _TORCHBIND_IMPLS_INITIALIZED
+    global _TENSOR_QUEUE_GLOBAL_TEST
     if _TORCHBIND_IMPLS_INITIALIZED:
         return
 
     load_torchbind_test_lib()
     register_fake_operators()
     register_fake_classes()
+    _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
     _TORCHBIND_IMPLS_INITIALIZED = True
 
 
+def _empty_tensor_queue() -> torch.ScriptObject:
+    return torch.classes._TorchScriptTesting._TensorQueue(
+        torch.empty(
+            0,
+        ).fill_(-1)
+    )
+
+
 # put these under a function because the corresponding library might not be loaded yet.
 def register_fake_operators():
     @torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta")
@@ -67,25 +80,23 @@ def register_fake_classes():
             self.y = y
 
         @classmethod
-        def from_real(cls, foo):
-            (x, y), _ = foo.__getstate__()
-            return cls(x, y)
+        def __obj_unflatten__(cls, flattend_foo):
+            return cls(**dict(flattend_foo))
 
         def add_tensor(self, z):
             return (self.x + self.y) * z
 
     @torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor")
     class FakeContainsTensor:
-        def __init__(self, x: torch.Tensor):
-            self.x = x
+        def __init__(self, t: torch.Tensor):
+            self.t = t
 
         @classmethod
-        def from_real(cls, foo):
-            ctx = torch.library.get_ctx()
-            return cls(ctx.to_fake_tensor(foo.get()))
+        def __obj_unflatten__(cls, flattend_foo):
+            return cls(**dict(flattend_foo))
 
         def get(self):
-            return self.x
+            return self.t
 
 
 def load_torchbind_test_lib():
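
Typical use of these helpers in the tests above:

init_torchbind_implementations()  # load the C++ test lib and register fakes (idempotent)
tq = _empty_tensor_queue()        # fresh queue per call, avoiding input aliasing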