[dynamo] support torchbind object input (#124978)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124978
Approved by: https://github.com/jansel
This commit is contained in:
ydwu4 2024-05-06 15:03:02 -07:00 committed by PyTorch MergeBot
parent c165a8e71d
commit 461ffaaaf3
15 changed files with 766 additions and 110 deletions

View File

@ -59,6 +59,10 @@ struct Foo : torch::CustomClassHolder {
bool eq(c10::intrusive_ptr<Foo> other) {
return this->x == other->x && this->y == other->y;
}
std::tuple<std::tuple<std::string, int64_t>, std::tuple<std::string, int64_t>>
__obj_flatten__() {
return std::tuple(std::tuple("x", this->x), std::tuple("y", this->y));
}
};
struct _StaticMethod : torch::CustomClassHolder {
@ -199,6 +203,10 @@ struct TensorQueue : torch::CustomClassHolder {
return raw_queue;
}
std::tuple<std::tuple<std::string, std::vector<at::Tensor>>> __obj_flatten__() {
return std::tuple(std::tuple("queue", this->get_raw_queue()));
}
private:
std::deque<at::Tensor> queue_;
std::mutex mutex_;
@ -370,6 +378,10 @@ struct ContainsTensor : public torch::CustomClassHolder {
return t_;
}
std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
return std::tuple(std::tuple("t", this->t_));
}
at::Tensor t_;
};
@ -417,6 +429,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
.def("add_tensor", &Foo::add_tensor)
.def("__eq__", &Foo::eq)
.def("combine", &Foo::combine)
.def("__obj_flatten__", &Foo::__obj_flatten__)
.def_pickle(
[](c10::intrusive_ptr<Foo> self) { // __getstate__
return std::vector<int64_t>{self->x, self->y};
@ -424,6 +437,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
[](std::vector<int64_t> state) { // __setstate__
return c10::make_intrusive<Foo>(state[0], state[1]);
});
m.def(
"takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
m.def(
@ -551,6 +565,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
m.class_<ContainsTensor>("_ContainsTensor")
.def(torch::init<at::Tensor>())
.def("get", &ContainsTensor::get)
.def("__obj_flatten__", &ContainsTensor::__obj_flatten__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<ContainsTensor>& self) -> at::Tensor {
@ -568,6 +583,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
.def("size", &TensorQueue::size)
.def("clone_queue", &TensorQueue::clone_queue)
.def("get_raw_queue", &TensorQueue::get_raw_queue)
.def("__obj_flatten__", &TensorQueue::__obj_flatten__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<TensorQueue>& self)
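The __obj_flatten__ methods added here define the flattening contract the Python side relies on: a custom class reports its state as a tuple of (name, value) pairs, and a registered fake class rebuilds itself from that tuple through an __obj_unflatten__ classmethod (see the test and registry changes below). A minimal round-trip sketch, assuming the torchbind test library is loaded and _Foo is constructed as in the existing tests:

import torch
from torch.testing._internal.torchbind_impls import init_torchbind_implementations

init_torchbind_implementations()  # loads the test library and registers fake classes

foo = torch.classes._TorchScriptTesting._Foo(1, 2)
flat = foo.__obj_flatten__()   # (("x", 1), ("y", 2))
state = dict(flat)             # {"x": 1, "y": 2}

# A fake class registered for _Foo rebuilds itself from the flattened state:
#     @classmethod
#     def __obj_unflatten__(cls, flattened_foo):
#         return cls(**dict(flattened_foo))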

View File

@ -3,8 +3,11 @@
import torch
import torch.utils._pytree as pytree
from torch._dynamo.testing import EagerAndRecordGraphs
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch._higher_order_ops.wrap import wrap
from torch._library.fake_class_registry import FakeScriptObject
from torch.export import export
from torch.export._trace import _export
@ -16,7 +19,39 @@ from torch.testing._internal.common_utils import (
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.testing._internal.torchbind_impls import (
_empty_tensor_queue,
init_torchbind_implementations,
)
def _assertEqualSkipScriptObject(test_case, exp, actual):
flat_exp = pytree.tree_leaves(exp)
flat_actual = pytree.tree_leaves(actual)
test_case.assertEqual(len(flat_exp), len(flat_actual))
for a, b in zip(flat_exp, flat_actual):
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
continue
test_case.assertEqual(a, b)
def _check_script_obj_equal(test_case, a: torch.ScriptObject, b: torch.ScriptObject):
test_case.assertEqual(a._type().qualified_name(), b._type().qualified_name())
test_case.assertEqual(a.__obj_flatten__(), b.__obj_flatten__())
def _assertEqualScriptObject(
test_case, exp, actual, check_obj_eq=_check_script_obj_equal
):
flat_exp = pytree.tree_leaves(exp)
flat_actual = pytree.tree_leaves(actual)
test_case.assertEqual(len(flat_exp), len(flat_actual))
for a, b in zip(flat_exp, flat_actual):
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
check_obj_eq(test_case, a, b)
else:
test_case.assertEqual(a, b)
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
@ -37,9 +72,8 @@ class TestExportTorchbind(TestCase):
self.y = y
@classmethod
def from_real(cls, foo):
(x, y), _ = foo.__getstate__()
return cls(x, y)
def __obj_unflatten__(cls, flattened_foo):
return cls(**dict(flattened_foo))
def add_tensor(self, z):
test.foo_add_tensor_counter += 1
@ -47,14 +81,12 @@ class TestExportTorchbind(TestCase):
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
class FakeTensorQueue:
def __init__(self, q):
self.queue = q
def __init__(self, queue):
self.queue = queue
@classmethod
def from_real(cls, real_tq):
ctx = torch.library.get_ctx()
fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
return cls(fake_queue)
def __obj_unflatten__(cls, flattened_ctx):
return cls(**dict(flattened_ctx))
def push(self, x):
test.tq_push_counter += 1
@ -89,15 +121,6 @@ class TestExportTorchbind(TestCase):
"_TorchScriptTesting::_TensorQueue"
)
def _assertEqualSkipScriptObject(self, exp, actual):
flat_exp = pytree.tree_leaves(exp)
flat_actual = pytree.tree_leaves(actual)
self.assertEqual(len(flat_exp), len(flat_actual))
for a, b in zip(flat_exp, flat_actual):
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
continue
self.assertEqual(a, b)
def _test_export_same_as_eager(
self, f, args, kwargs=None, strict=True, pre_dispatch=False
):
@ -532,7 +555,7 @@ def forward(self, arg0_1, arg1_1):
""",
)
mod.check_tq_is_fake = False
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
_assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x))
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
def test_make_fx_tensor_queue_methods_fakify_internal_states(
@ -592,7 +615,7 @@ def forward(self, arg0_1, arg1_1):
)
# turn off tq type checking in eager execution
mod.check_tq_is_fake = False
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
_assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x))
self.assertEqual(tq.size(), 0)
self.assertEqual(tq1.size(), 0)
@ -769,7 +792,7 @@ def forward(self, arg0_1, arg1_1):
add = torch.ops.aten.add.Tensor(queue_pop_1, 0); queue_pop_1 = None
return (sub, add, arg0_1)""",
)
self._assertEqualSkipScriptObject(gm(tq1, x), mod(tq2, x))
_assertEqualSkipScriptObject(self, gm(tq1, x), mod(tq2, x))
def test_aot_export_tensor_queue_operators(self):
class Model(torch.nn.Module):
@ -829,6 +852,377 @@ def forward(self, arg0_1, arg1_1, arg2_1):
)
class TestCompileTorchbind(TestCase):
def setUp(self):
init_torchbind_implementations()
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
class FakeTensorQueue:
def __init__(self, queue):
self.queue = queue
@classmethod
def __obj_unflatten__(cls, flattened_ctx):
return cls(**dict(flattened_ctx))
def push(self, x):
self.queue.append(x)
def pop(self):
return self.queue.pop(0)
def size(self):
return len(self.queue)
torch._dynamo.reset()
def tearDown(self):
torch._dynamo.reset()
def test_compile_script_object_input(self):
backend = EagerAndRecordGraphs()
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.check_tq_is_fake = True
def forward(self, tq, x):
tq.push(x.cos())
tq.push(x.sin())
x_sin = tq.pop() - tq.size()
return x_sin, tq
mod = Model()
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq2 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq3 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq4 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
x = torch.randn(2, 3)
ret = torch.compile(mod, backend=backend)(tq1, x)
eager_ret = mod(tq2, x)
_assertEqualSkipScriptObject(self, ret, eager_ret)
self.assertEqual(ret[1].size(), eager_ret[1].size())
self.assertEqual(ret[1].pop(), eager_ret[1].pop())
# Note that the dynamo-captured graph does not return L_tq_ as an output.
# This is because dynamo detects that L_tq_ is an input and therefore
# doesn't return it as a graph output. Related logic is in dynamo/codegen.py.
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, L_tq_ : torch.ScriptObject, L_x_ : torch.Tensor):
l_tq_ = L_tq_
l_x_ = L_x_
cos = l_x_.cos()
call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', cos); cos = None
sin = l_x_.sin(); l_x_ = None
call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', sin); sin = None
call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'pop')
call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None
x_sin = call_torchbind_2 - 1; call_torchbind_2 = None
return (x_sin,)""",
)
def test_compile_script_object_input_guards(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.check_tq_is_fake = True
def forward(self, tq, x):
tq.push(x.cos())
tq.push(x.sin())
x_sin = tq.pop() - tq.size()
return x_sin, tq
mod = Model()
cnt = torch._dynamo.testing.CompileCounter()
x = torch.randn(2, 3)
tq1 = _empty_tensor_queue()
torch.compile(mod, backend=cnt)(tq1, x)
self.assertEqual(cnt.frame_count, 1)
tq2 = _empty_tensor_queue()
for _ in range(10):
tq2.push(torch.randn(4, 5, requires_grad=False))
torch.compile(mod, backend=cnt)(tq2, x)
# A change in queue length causes a re-compile
self.assertEqual(cnt.frame_count, 2)
tq3 = _empty_tensor_queue()
tq3.push(torch.randn(2, 3, requires_grad=False))
torch.compile(mod, backend=cnt)(tq3, x)
# A shape change of the tensor in the queue causes a re-compile
self.assertEqual(cnt.frame_count, 3)
tq4 = _empty_tensor_queue()
tq4.push(torch.randn(2, 3, requires_grad=False))
torch.compile(mod, backend=cnt)(tq4, x)
# No recompile
self.assertEqual(cnt.frame_count, 3)
tq5 = _empty_tensor_queue()
tq5.push(torch.randn(2, 3, requires_grad=True))
torch.compile(mod, backend=cnt)(tq5, x)
# A dispatch key change (here via requires_grad) of the tensor in the queue causes a re-compile
self.assertEqual(cnt.frame_count, 4)
tq6 = _empty_tensor_queue()
tq6.push(torch.randn(2, 3, requires_grad=True, dtype=torch.float64))
torch.compile(mod, backend=cnt)(tq6, x)
# A dtype change of the tensor in the queue causes a re-compile
self.assertEqual(cnt.frame_count, 5)
def test_compile_script_object_input_automatic_dynamic_shape(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.check_tq_is_fake = True
def forward(self, tq, x):
tq.push(x.cos())
tq.push(x.sin())
x_sin = tq.pop() - tq.size()
return x_sin, tq
mod = Model()
cnt = torch._dynamo.testing.CompileCounter()
x = torch.randn(2, 3)
tq1 = _empty_tensor_queue()
tq1.push(torch.randn(2, 3, requires_grad=False))
torch.compile(mod, backend=cnt)(tq1, x)
self.assertEqual(cnt.frame_count, 1)
tq2 = _empty_tensor_queue()
# make the first tensor's second dim dynamic
tq2.push(torch.randn(2, 4, requires_grad=False))
torch.compile(mod, backend=cnt)(tq2, x)
self.assertEqual(cnt.frame_count, 2)
tq3 = _empty_tensor_queue()
tq3.push(torch.randn(2, 5, requires_grad=False))
# should not trigger a recompilation
torch.compile(mod, backend=cnt)(tq3, x)
self.assertEqual(cnt.frame_count, 2)
def test_compile_error_on_input_aliasing_contents(self):
backend = EagerAndRecordGraphs()
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.check_tq_is_fake = True
def forward(self, tq, x):
return x.sin(), tq.pop().cos()
x = torch.randn(2, 3)
mod = Model()
tq1 = _empty_tensor_queue()
tq1.push(x)
with self.assertRaisesRegex(RuntimeError, "is alising"):
torch.compile(mod, backend=backend)(tq1, x)
def test_compile_error_on_script_obj_setattr(self):
def setattr_f(tq):
tq.a = 1
return tq
with self.assertRaisesRegex(
RuntimeError, "call method __setattr__ on script object is not safe"
):
torch.compile(setattr_f, backend="eager")(_empty_tensor_queue())
def test_compile_error_on_script_obj_missing_attr(self):
def setattr_f(tq):
return tq._not_defined_attr
with self.assertRaisesRegex(
RuntimeError, "doesn't define method _not_defined_attr"
):
torch.compile(setattr_f, backend="eager")(_empty_tensor_queue())
def test_compile_body_aliasing_contents(self):
backend = EagerAndRecordGraphs()
def f(tq, x):
x1 = x.view(-1)
x2 = x.permute(1, 0)
tq.push(x1)
tq.push(x2)
return x1 - tq.size(), x2 + tq.size(), tq
x = torch.randn(2, 3)
_assertEqualScriptObject(
self,
f(_empty_tensor_queue(), x),
torch.compile(f, backend=backend)(_empty_tensor_queue(), x),
)
if not torch._dynamo.is_compiling():
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
l_x_ = L_x_
l_tq_ = L_tq_
x1 = l_x_.view(-1)
x2 = l_x_.permute(1, 0); l_x_ = None
call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x1)
call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x2)
call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'size')
sub = x1 - 2; x1 = None
call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None
add = x2 + 2; x2 = None
return (sub, add)""",
)
def test_compile_error_on_non_fakified_method(self):
backend = EagerAndRecordGraphs()
def f(tq, x):
x1 = x.view(-1)
x2 = x.permute(1, 0)
tq.push(x1)
tq.push(x2)
# Though the real TensorQueue implements a clone_queue method,
# the fakified version doesn't.
flat_obj = tq.clone_queue()
return flat_obj
x = torch.randn(2, 3)
with self.assertRaisesRegex(
RuntimeError, "FakeScriptObject doesn't define method"
):
torch.compile(f, backend=backend)(_empty_tensor_queue(), x)
def test_compile_obj_as_hop_input(self):
def f(tq, x):
def fn(tq, x):
tq.push(x)
return x.sin()
return wrap(fn, tq, x)
x = torch.randn(2, 3)
_assertEqualScriptObject(
self,
f(_empty_tensor_queue(), x),
torch.compile(f, backend="eager")(_empty_tensor_queue(), x),
)
def test_compile_obj_closure(self):
def f(x):
def inner_f(x):
tq.push(x.sin())
inner_f(x)
return tq.pop(), tq
opt_f = torch.compile(f, backend="eager")
tq = _empty_tensor_queue()
x = torch.randn(3, 2)
_assertEqualScriptObject(self, f(x), opt_f(x))
def test_compile_global_obj(self):
global _TENSOR_QUEUE_GLOBAL_TEST
_TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
def f(x):
_TENSOR_QUEUE_GLOBAL_TEST.push(x.sin())
return _TENSOR_QUEUE_GLOBAL_TEST.pop(), _TENSOR_QUEUE_GLOBAL_TEST
opt_f = torch.compile(f, backend="eager")
x = torch.randn(3, 2)
eager_ret = f(x)
opt_ret = opt_f(x)
_assertEqualScriptObject(self, eager_ret, opt_ret)
def test_compile_obj_graph_breaks(self):
cnt = torch._dynamo.testing.CompileCounter()
def f(tq, x):
tq.push(x.sin())
tq.push(x.sin())
torch._dynamo.graph_break()
tq.pop()
torch._dynamo.graph_break()
tq.push(x.cos() + tq.size())
torch._dynamo.graph_break()
tq.push(x.cos() - tq.size())
return x, tq.pop(), tq
opt_f = torch.compile(f, backend=cnt)
x = torch.randn(3, 2)
_assertEqualScriptObject(
self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x)
)
self.assertEqual(cnt.frame_count, 4)
def test_compile_obj_attributes(self):
backend = EagerAndRecordGraphs()
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.tq = _empty_tensor_queue()
def forward(self, x):
self.tq.push(x)
return self.tq.pop()
x = torch.randn(2, 3)
opt_f = torch.compile(Model(), backend=backend)
_assertEqualScriptObject(self, Model()(x), opt_f(x))
self.assertEqual(len(backend.graphs), 1)
# lifted as input. In the future, we would want to consolidate this
# with the non-strict behavior, where they're set as attributes.
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, L_self_tq : torch.ScriptObject, L_x_ : torch.Tensor):
l_self_tq = L_self_tq
l_x_ = L_x_
call_torchbind = torch.ops.higher_order.call_torchbind(l_self_tq, 'push', l_x_); l_x_ = None
call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_self_tq, 'pop'); l_self_tq = None
return (call_torchbind_1,)""",
)
def test_compile_obj_torchbind_op(self):
def f(tq, x):
torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1)
torch.ops._TorchScriptTesting.queue_pop(tq)
torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
return tq.pop(), tq.pop() + tq.size(), tq
opt_f = torch.compile(f, backend="eager")
x = torch.randn(2)
_assertEqualScriptObject(
self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x)
)
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestRegisterFakeClass(TestCase):
def setUp(self):
@ -845,7 +1239,9 @@ class TestRegisterFakeClass(TestCase):
pass
def test_register_fake_class_no_from_real(self):
with self.assertRaisesRegex(RuntimeError, "define a classmethod from_real"):
with self.assertRaisesRegex(
RuntimeError, "define a classmethod __obj_unflatten__"
):
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class InvalidFakeFoo:
@ -861,9 +1257,8 @@ class TestRegisterFakeClass(TestCase):
self.x = x
self.y = y
def from_real(self, foo_obj):
x, y = foo_obj.__getstate__()
return FakeFoo(x, y)
def __obj_unflatten__(cls, flattened_foo): # noqa: B902
return cls(**dict(flattened_foo))
def test_register_fake_class_valid(self):
class FakeFoo:
@ -872,9 +1267,8 @@ class TestRegisterFakeClass(TestCase):
self.y = y
@classmethod
def from_real(cls, foo_obj):
x, y = foo_obj.__getstate__()
return cls(x, y)
def __obj_unflatten__(cls, flattened_foo):
return cls(**dict(flattened_foo))
torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)

View File

@ -573,9 +573,8 @@ class TestUnflatten(TestCase):
self.y = y
@classmethod
def from_real(cls, foo):
(x, y), _ = foo.__getstate__()
return cls(x, y)
def __obj_unflatten__(cls, flat_ctx):
return cls(**dict(flat_ctx))
def add_tensor(self, z):
return (self.x + self.y) * z

View File

@ -17,34 +17,21 @@ from torch.testing import FileCheck
from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM80OrLater
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
run_tests,
TEST_CUDA,
TEST_WITH_ROCM,
TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.utils.hooks import RemovableHandle
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestWithEffects(TestCase):
def setUp(self):
if IS_MACOS:
raise unittest.SkipTest("non-portable load_library call used in test")
elif IS_SANDCASTLE or IS_FBCODE:
torch.ops.load_library(
"//caffe2/test/cpp/jit:test_custom_class_registrations"
)
elif IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
else:
lib_file_path = find_library_location("libtorchbind_test.so")
torch.ops.load_library(str(lib_file_path))
init_torchbind_implementations()
def test_print(self):
class M(torch.nn.Module):

View File

@ -171,6 +171,10 @@ class UserStopIteration(TorchDynamoException):
self.value = None
class UnsafeScriptObjectError(TorchDynamoException):
pass
class UncapturedHigherOrderOpError(TorchDynamoException):
pass

View File

@ -41,6 +41,7 @@ except ModuleNotFoundError:
import torch
import torch.utils._device
from torch._dynamo.source import (
is_from_flatten_script_object_source,
is_from_local_source,
is_from_optimizer_source,
TensorProperty,
@ -72,6 +73,7 @@ from .source import (
ChainedSource,
ConstDictKeySource,
DefaultsSource,
FlattenScriptObjectSource,
FSDPNNModuleSource,
GetItemSource,
GlobalSource,
@ -84,6 +86,7 @@ from .source import (
NumpyTensorSource,
ODictGetItemSource,
OptimizerSource,
ScriptObjectQualifiedNameSource,
ShapeEnvSource,
TupleIteratorGetItemSource,
TypeSource,
@ -957,6 +960,22 @@ class GuardBuilder(GuardBuilderBase):
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, FlattenScriptObjectSource):
assert base_guard_manager # to make mypy happy
return base_guard_manager.lambda_manager(
python_lambda=lambda x: x.__obj_flatten__(),
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, ScriptObjectQualifiedNameSource):
assert base_guard_manager # to make mypy happy
return base_guard_manager.lambda_manager(
python_lambda=lambda x: x._type().qualified_name(),
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, TupleIteratorGetItemSource):
assert base_guard_manager # to make mypy happy
return base_guard_manager.tuple_iterator_getitem_manager(
@ -2602,6 +2621,14 @@ def make_dupe_guard(obj_source, dupe_source):
if dupe_source and dupe_source != obj_source:
ser_source_is_local = is_from_local_source(dupe_source)
source_is_local = is_from_local_source(obj_source)
if is_from_flatten_script_object_source(
dupe_source
) or is_from_flatten_script_object_source(obj_source):
raise exc.UnsafeScriptObjectError(
f"{obj_source.name()} is alising {dupe_source.name()}. This is not supported."
f" Please do a clone for corresponding input."
)
# Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
# reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
# so maybe we should do this refactor before we land this...
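In effect, the two lambda guard managers above re-evaluate _type().qualified_name() and __obj_flatten__() on the real input at guard time, and tensor guards are installed on the flattened contents. A rough, self-contained sketch of what the combined check amounts to (an illustration only, not the actual guard code; script_obj_guards_hold and _tensor_meta are made-up names):

import torch
import torch.utils._pytree as pytree

def _tensor_meta(leaf):
    # Stand-in for the tensor guards installed on the flattened contents.
    if isinstance(leaf, torch.Tensor):
        return (tuple(leaf.shape), leaf.dtype, leaf.requires_grad)
    return leaf

def script_obj_guards_hold(obj, seen_qualname, seen_flat_meta):
    # Type guard: the input must still be the same registered torchbind class.
    if obj._type().qualified_name() != seen_qualname:
        return False
    # Content guard: queue length, shapes, dtypes and requires_grad of the
    # flattened contents must match what was recorded at compile time;
    # otherwise the frame is recompiled (see the guard tests above).
    return pytree.tree_map(_tensor_meta, obj.__obj_flatten__()) == seen_flat_meta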

View File

@ -1434,7 +1434,11 @@ class OutputGraph:
self.remove_node(node)
self.real_value_cache.pop(node, None)
used_symbols = set()
used_symbols: Set[sympy.Symbol] = set()
def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]):
used_symbols |= free_symbols(fake)
recheck_placeholders = []
for node in self.placeholders:
binds_symbol = placeholder_binds_symbol(node) is not None
@ -1452,10 +1456,22 @@ class OutputGraph:
arg = node.meta["grapharg"]
if isinstance(arg, BackwardStateGraphArg):
continue
if isinstance(node.meta["grapharg"].example, torch.ScriptObject):
real_script_obj = node.meta["grapharg"].example
fake_script_obj = node.meta["grapharg"].example_strong_ref
flat_dict = dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined]
for attr in flat_dict.keys():
fake_attr_val = getattr(fake_script_obj.wrapped_obj, attr)
pytree.tree_map_only(
(torch.SymInt, torch.Tensor),
lambda t: update_used_symbols(used_symbols, t),
fake_attr_val,
)
continue
fake = (
arg.fake_tensor if arg.fake_tensor is not None else arg.example
)
used_symbols |= free_symbols(fake)
update_used_symbols(used_symbols, fake)
# After removing unused graphargs, prune unused binds_symbol
for node in recheck_placeholders:

View File

@ -299,6 +299,36 @@ class ConvertIntSource(ChainedSource):
return f"cast_symbool_to_symint_guardless({self.base.name()})"
@dataclasses.dataclass(frozen=True)
class FlattenScriptObjectSource(ChainedSource):
def __post_init__(self):
assert self.base is not None
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
def guard_source(self):
return self.base.guard_source()
def name(self):
return f"{self.base.name()}.__obj_flatten__()"
@dataclasses.dataclass(frozen=True)
class ScriptObjectQualifiedNameSource(ChainedSource):
def __post_init__(self):
assert self.base is not None
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
def guard_source(self):
return self.base.guard_source()
def name(self):
return f"{self.base.name()}._type().qualified_name()"
@dataclasses.dataclass(frozen=True)
class DefaultsSource(ChainedSource):
idx_key: Union[int, str]
@ -563,6 +593,14 @@ def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
return True
def is_from_flatten_script_object_source(source: Source):
if isinstance(source, FlattenScriptObjectSource):
return True
elif isinstance(source, ChainedSource):
return is_from_flatten_script_object_source(source.base)
return False
def is_from_optimizer_source(source: Source):
if isinstance(source, OptimizerSource):
return True
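For reference, the guard expressions built from these sources chain off the base source's name; assuming a local input named tq (dynamo renders local inputs as L['tq']), the names come out roughly as:

from torch._dynamo.source import (
    FlattenScriptObjectSource,
    LocalSource,
    ScriptObjectQualifiedNameSource,
)

src = LocalSource("tq")                       # a local input named tq
FlattenScriptObjectSource(src).name()         # "L['tq'].__obj_flatten__()"
ScriptObjectQualifiedNameSource(src).name()   # "L['tq']._type().qualified_name()"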

View File

@ -13,7 +13,7 @@ import operator
import re
import sys
import types
from typing import List, NamedTuple, Optional, Union
from typing import Any, List, NamedTuple, Optional, Union
from torch.utils._sympy.value_ranges import ValueRanges
@ -26,6 +26,7 @@ import torch
from torch import SymInt
from torch._guards import GuardSource, TracingContext
from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator
from torch._streambase import _EventBase, _StreamBase
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
@ -154,6 +155,7 @@ from .misc import (
)
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
from .optimizer import OptimizerVariable
from .script_object import TorchScriptObjectVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
@ -904,13 +906,65 @@ class VariableBuilder:
user_cls_source=AttrSource(self.source, "__class__"),
),
)
elif TorchScriptObjectVariable.is_matching_cls(type(value)):
from ..source import (
FlattenScriptObjectSource,
ScriptObjectQualifiedNameSource,
)
# This fallback exists to allow a smoother transition: script object classes
# without __obj_flatten__ are still wrapped as plain user-defined objects.
# The implications are:
# - the script object won't be tracked as a proxy,
# - methods on it won't show up in the graph,
# - the original script object might be mutated.
if not hasattr(value, "__obj_flatten__"):
return self.wrap_user_defined(value)
# Install the guards on the fully qualified name of the script object
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, ScriptObjectQualifiedNameSource(self.source))(
value._type().qualified_name() # type: ignore[attr-defined]
)
)
# Install the guards on the content of the script object by setting the source
# to be FlattenScriptObjectSource, which calls __obj_flatten__() to get the contents.
LazyVariableTracker.realize_all(
VariableBuilder(self.tx, FlattenScriptObjectSource(self.source))(
value.__obj_flatten__()
)
)
fake_script_obj = torch._library.fake_class_registry.to_fake_obj(
self.tx.output.fake_mode, value
)
proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(value),
source=self.source,
)
# Setting is_unspecialized=False so that reconstruct doesn't insert an as_tensor call by default.
# Setting example to the real value because these example values will be used
# as example_inputs for the user compiler.
proxy.node.meta["grapharg"] = GraphArg(
self.source, value, False, None, False, fake_script_obj
)
return TorchScriptObjectVariable.create(
proxy,
fake_script_obj,
source=self.source,
)
else:
self.install_guards(GuardBuilder.TYPE_MATCH)
result = UserDefinedObjectVariable(value, source=self.source)
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
# don't allow STORE_ATTR mutation with custom __setattr__
return result
return self.tx.output.side_effects.track_object_existing(value, result)
return self.wrap_user_defined(value)
def wrap_user_defined(self, value: Any):
self.install_guards(GuardBuilder.TYPE_MATCH)
result = UserDefinedObjectVariable(value, source=self.source)
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
# don't allow STORE_ATTR mutation with custom __setattr__
return result
return self.tx.output.side_effects.track_object_existing(value, result)
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
if config.specialize_int and type(value) is torch.Size:
@ -1782,6 +1836,12 @@ def wrap_fx_proxy_cls(
]:
set_example_value(proxy.node, example_value)
return ConstantVariable.create(example_value, **options)
elif (
isinstance(example_value, (int, float, bool))
and proxy.node.target is call_torchbind
):
set_example_value(proxy.node, example_value)
return ConstantVariable.create(example_value, **options)
else:
unimplemented(
"torch.* op returned non-Tensor "

View File

@ -549,6 +549,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
return StrictModeHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "associative_scan":
return AssociativeScanHigherOrderVariable(value, source, **kwargs)
elif value.__name__ == "call_torchbind":
return CallTorchbindHigherOrderVariable(value, source, **kwargs)
else:
unimplemented(f"HigherOrderOperator {value.__name__}")
@ -769,6 +771,34 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
class CallTorchbindHigherOrderVariable(TorchHigherOrderOperatorVariable):
def __init__(self, hop, source, script_obj_var, method_name):
super().__init__(hop, source)
self.script_obj_var = script_obj_var
self.method_name = method_name
def call_function(
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
from .builder import wrap_fx_proxy
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
args_proxy = [arg.as_proxy() for arg in args]
kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()}
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=tuple(
[self.script_obj_var.as_proxy(), self.method_name] + args_proxy
),
kwargs=kwargs_proxy,
),
)
class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
@raise_hard_error_if_graph_break(
reason="while_loop doesn't work unless it is captured completely with torch.compile."

View File

@ -0,0 +1,80 @@
import functools
from typing import Dict
import torch
from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
from .base import VariableTracker
from .user_defined import UserDefinedObjectVariable
def _raise_hard_error_if_graph_break(reason):
def deco(fn):
@functools.wraps(fn)
def graph_break_as_hard_error(*args, **kwargs):
try:
return fn(*args, **kwargs)
except Unsupported as e:
raise UnsafeScriptObjectError(e.msg) from e
return graph_break_as_hard_error
return deco
class TorchScriptObjectVariable(UserDefinedObjectVariable):
_fake_script_object_cache: Dict[int, "TorchScriptObjectVariable"] = {}
@classmethod
def is_matching_cls(cls, user_cls: type):
return issubclass(user_cls, torch.ScriptObject)
@staticmethod
def create(proxy, value, **options):
return TorchScriptObjectVariable(proxy, value, **options)
def __init__(self, proxy, value, source, **kwargs):
super().__init__(value, **kwargs)
self.proxy = proxy
self.proxy.node.meta["example_value"] = value
self.source = source
def as_proxy(self):
return self.proxy
@_raise_hard_error_if_graph_break(
"Dynamo cannot safely trace script object due to graph break."
)
def var_getattr(self, tx, name: str) -> VariableTracker:
from torch._higher_order_ops.torchbind import call_torchbind
from ..source import AttrSource
from .higher_order_ops import TorchHigherOrderOperatorVariable
method = getattr(self.value, name, None)
if method is None:
unimplemented(
f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?"
)
if not callable(method):
unimplemented(
"Only method calls on TorchScript objects can be supported safely."
" Please use method calls instead of attribute access."
)
return TorchHigherOrderOperatorVariable.make(
call_torchbind,
source=AttrSource(self.source, name),
script_obj_var=self,
method_name=name,
)
# We only support method calls on script objects. Interpreting the bytecodes
# should go through var_getattr then call_function instead of call_method.
#
# However, it's possible for call_method to be used directly e.g. for __setattr__.
@_raise_hard_error_if_graph_break(
"Dynamo cannot safely trace script object due to graph break."
)
def call_method(self, tx, name, args, kwargs):
unimplemented(f"call method {name} on script object is not safe.")

View File

@ -205,38 +205,3 @@ class AbstractImplCtx:
result, min=min, max=max
)
return result
def to_fake_tensor(self, tensor: torch.Tensor):
"""
Creates a fake tensor from a concrete tensor. Note: this is not needed for register_fake.
This is useful for register_fake_class (which is necessary for torch.compile) for custom class.
Users need to implement a from_real method that takes a real custom object and creates a fake
custom object. Users can use this API to create fake tensors for the tensor states in the custom object.
Args:
tensor (torch.Tensor): A concrete tensor.
Example::
>>> import torch
>>> @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") # xdoctest: +SKIP
... class FakeTensorQueue:
... def __init__(self, q):
... self.queue = q
...
... @classmethod
... def from_real(cls, real_tq):
... ctx = torch.library.get_ctx()
... fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
... return cls(fake_queue)
...
... def push(self, x):
... self.queue.append(x)
...
... def pop(self):
... return self.queue.pop(0)
...
... def size(self):
... return len(self.queue)
"""
return self._fake_mode.from_tensor(tensor)

View File

@ -65,8 +65,35 @@ class FakeClassRegistry:
global_fake_class_registry = FakeClassRegistry()
# TODO: add this check at compile time for __obj_flatten__.
def _check_valid_flat_script_obj(flat_x):
if not isinstance(flat_x, tuple):
raise RuntimeError("Expect flat x to be a tuple.")
for tp in flat_x:
if not isinstance(tp, tuple):
raise RuntimeError("Expect flat x to be a tuple of tuples.")
if not len(tp) == 2 or not isinstance(tp[0], str):
raise RuntimeError(
"Expect element of flat x to be a tuple of two elements with first element being a string"
)
def to_fake_obj(fake_mode, x: torch.ScriptObject) -> FakeScriptObject:
fake_x = _fake_obj_from_real(fake_mode, x)
import torch.utils._pytree as pytree
flat_x = x.__obj_flatten__() # type: ignore[attr-defined]
_check_valid_flat_script_obj(flat_x)
fake_flattened = pytree.tree_map_only(
torch.Tensor,
lambda t: fake_mode.from_tensor(t),
flat_x,
)
fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)
def _call_torchbind(method_name):
from torch._higher_order_ops.torchbind import call_torchbind
@ -134,14 +161,13 @@ def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal]
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
class FakeTensorQueue:
def __init__(self, q):
self.queue = q
def __init__(self, queue):
self.queue = queue
@classmethod
def from_real(cls, real_tq):
ctx = torch.library.get_ctx()
fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.clone_queue()]
return cls(fake_queue)
def __obj_unflatten__(cls, flattened_ctx):
ctx = {flattened_ctx[0]: flattened_ctx[1]}
return cls(**ctx)
def push(self, x):
self.queue.append(x)
@ -162,7 +188,9 @@ def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal]
from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
if not from_method:
raise RuntimeError(f"{fake_class} doesn't define a classmethod from_real.")
raise RuntimeError(
f"{fake_class} doesn't define a classmethod {_CONVERT_FROM_REAL_NAME}."
)
if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod):
raise RuntimeError(
@ -221,7 +249,7 @@ def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any:
return fake_class
_CONVERT_FROM_REAL_NAME = "from_real"
_CONVERT_FROM_REAL_NAME = "__obj_unflatten__"
def _fake_obj_from_real(fake_mode, x) -> Any:
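to_fake_obj now implements the flattening contract directly: it calls __obj_flatten__() on the real object, replaces every tensor leaf with a fake tensor, and hands the result to the registered fake class's __obj_unflatten__. A usage sketch (assumes the torchbind test library and the FakeTensorQueue registration shown earlier):

import torch
from torch._library.fake_class_registry import to_fake_obj
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.torchbind_impls import init_torchbind_implementations

init_torchbind_implementations()

tq = torch.classes._TorchScriptTesting._TensorQueue(torch.empty(0).fill_(-1))
tq.push(torch.randn(2, 3))

fake_tq = to_fake_obj(FakeTensorMode(), tq)  # a FakeScriptObject wrapping a FakeTensorQueue
# fake_tq.wrapped_obj.queue now holds FakeTensors mirroring the real tensors' metadata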

View File

@ -58,6 +58,7 @@ hop_that_doesnt_have_opinfo_test_allowlist = [
"with_effects",
"strict_mode",
"_export_tracepoint",
"call_torchbind",
]
torch.library.define(

View File

@ -1,22 +1,35 @@
import contextlib
from typing import Optional
import torch
_TORCHBIND_IMPLS_INITIALIZED = False
_TENSOR_QUEUE_GLOBAL_TEST: Optional[torch.ScriptObject] = None
def init_torchbind_implementations():
global _TORCHBIND_IMPLS_INITIALIZED
global _TENSOR_QUEUE_GLOBAL_TEST
if _TORCHBIND_IMPLS_INITIALIZED:
return
load_torchbind_test_lib()
register_fake_operators()
register_fake_classes()
_TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
_TORCHBIND_IMPLS_INITIALIZED = True
def _empty_tensor_queue() -> torch.ScriptObject:
return torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
# Put these inside functions because the corresponding library might not be loaded yet.
def register_fake_operators():
@torch.library.register_fake("_TorchScriptTesting::takes_foo_python_meta")
@ -67,25 +80,23 @@ def register_fake_classes():
self.y = y
@classmethod
def from_real(cls, foo):
(x, y), _ = foo.__getstate__()
return cls(x, y)
def __obj_unflatten__(cls, flattened_foo):
return cls(**dict(flattened_foo))
def add_tensor(self, z):
return (self.x + self.y) * z
@torch._library.register_fake_class("_TorchScriptTesting::_ContainsTensor")
class FakeContainsTensor:
def __init__(self, x: torch.Tensor):
self.x = x
def __init__(self, t: torch.Tensor):
self.t = t
@classmethod
def from_real(cls, foo):
ctx = torch.library.get_ctx()
return cls(ctx.to_fake_tensor(foo.get()))
def __obj_unflatten__(cls, flattened_foo):
return cls(**dict(flattened_foo))
def get(self):
return self.x
return self.t
def load_torchbind_test_lib():