diff --git a/test/cpp/jit/test_custom_class_registrations.cpp b/test/cpp/jit/test_custom_class_registrations.cpp index f81952ea9cf..e92c397a034 100644 --- a/test/cpp/jit/test_custom_class_registrations.cpp +++ b/test/cpp/jit/test_custom_class_registrations.cpp @@ -194,6 +194,10 @@ struct TensorQueue : torch::CustomClassHolder { } return ret; } + std::vector get_raw_queue() { + std::vector raw_queue(queue_.begin(), queue_.end()); + return raw_queue; + } private: std::deque queue_; @@ -563,6 +567,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) { .def("top", &TensorQueue::top) .def("size", &TensorQueue::size) .def("clone_queue", &TensorQueue::clone_queue) + .def("get_raw_queue", &TensorQueue::get_raw_queue) .def_pickle( // __getstate__ [](const c10::intrusive_ptr& self) diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 8ba480d305c..0dbf088667c 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -1,21 +1,107 @@ # Owner(s): ["oncall: export"] +import unittest + import torch -import torch.testing._internal.torchbind_impls # noqa: F401 +import torch.utils._pytree as pytree from torch._higher_order_ops.torchbind import enable_torchbind_tracing +from torch._library.fake_class_registry import FakeScriptObject from torch.export import export from torch.export._trace import _export +from torch.fx.experimental.proxy_tensor import make_fx from torch.testing._internal.common_utils import ( + find_library_location, instantiate_parametrized_tests, + IS_FBCODE, + IS_MACOS, + IS_SANDCASTLE, + IS_WINDOWS, parametrize, run_tests, skipIfTorchDynamo, TestCase, ) +from torch.testing._internal.torchbind_impls import register_fake_operators + + +def load_torchbind_test_lib(): + if IS_SANDCASTLE or IS_FBCODE: + torch.ops.load_library("//caffe2/test/cpp/jit:test_custom_class_registrations") + elif IS_MACOS: + raise unittest.SkipTest("non-portable load_library call used in test") + else: + lib_file_path = find_library_location("libtorchbind_test.so") + if IS_WINDOWS: + lib_file_path = find_library_location("torchbind_test.dll") + torch.ops.load_library(str(lib_file_path)) + + register_fake_operators() @skipIfTorchDynamo("torchbind not supported with dynamo yet") class TestExportTorchbind(TestCase): + def setUp(self): + load_torchbind_test_lib() + + @torch._library.register_fake_class("_TorchScriptTesting::_Foo") + class FakeFoo: + def __init__(self, x: int, y: int): + self.x = x + self.y = y + + @classmethod + def from_real(cls, foo): + (x, y), _ = foo.__getstate__() + return cls(x, y) + + def add_tensor(self, z): + return (self.x + self.y) * z + + test = self + test.tq_push_counter = 0 + test.tq_pop_counter = 0 + test.tq_size_counter = 0 + + @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") + class FakeTensorQueue: + def __init__(self, q): + self.queue = q + + @classmethod + def from_real(cls, real_tq): + ctx = torch.library.get_ctx() + fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()] + return cls(fake_queue) + + def push(self, x): + test.tq_push_counter += 1 + self.queue.append(x) + + def pop(self): + test.tq_pop_counter += 1 + return self.queue.pop(0) + + def size(self): + test.tq_size_counter += 1 + return len(self.queue) + + def tearDown(self): + torch._library.fake_class_registry.deregister_fake_class( + "_TorchScriptTesting::_Foo" + ) + torch._library.fake_class_registry.deregister_fake_class( + "_TorchScriptTesting::_TensorQueue" + ) + + def _assertEqualSkipScriptObject(self, exp, actual): + flat_exp = pytree.tree_leaves(exp) + flat_actual = pytree.tree_leaves(actual) + self.assertEqual(len(flat_exp), len(flat_actual)) + for a, b in zip(flat_exp, flat_actual): + if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject): + continue + self.assertEqual(a, b) + def _test_export_same_as_eager( self, f, args, kwargs=None, strict=True, pre_dispatch=False ): @@ -351,6 +437,189 @@ def forward(self, arg0_1, attr, arg1_1): return (getitem_3, add_1)""", # noqa: B950 ) + @parametrize("make_fx_tracing_mode", ["fake", "symbolic"]) + def test_make_fx_tensor_queue_methods(self, make_fx_tracing_mode): + test = self + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 2) + self.check_tq_is_fake = True + + def forward(self, tq, x): + if self.check_tq_is_fake: + test.assertTrue(isinstance(tq, FakeScriptObject)) + tq.push(x.cos()) + tq.push(x.sin()) + x_cos = tq.pop() + tq.size() + x_sin = tq.pop() - tq.size() + return x_sin, x_cos, tq + + mod = Model() + tq = torch.classes._TorchScriptTesting._TensorQueue( + torch.empty( + 0, + ).fill_(-1) + ) + tq1 = torch.classes._TorchScriptTesting._TensorQueue( + torch.empty( + 0, + ).fill_(-1) + ) + x = torch.ones(2, 3) + gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x) + self.assertEqual(self.tq_push_counter, 2) + self.assertEqual(self.tq_pop_counter, 2) + self.assertEqual(self.tq_size_counter, 2) + self.assertEqual(tq.size(), 0) + self.assertExpectedInline( + gm.code.strip("\n"), + """\ +def forward(self, arg0_1, arg1_1): + cos = torch.ops.aten.cos.default(arg1_1) + call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'push', cos); cos = None + sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None + call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'push', sin); sin = None + call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop') + call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size') + add = torch.ops.aten.add.Tensor(call_torchbind_2, 1); call_torchbind_2 = None + call_torchbind_4 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop') + call_torchbind_5 = torch.ops.higher_order.call_torchbind(arg0_1, 'size') + sub = torch.ops.aten.sub.Tensor(call_torchbind_4, 0); call_torchbind_4 = None + return (sub, add, arg0_1) + """, + ) + mod.check_tq_is_fake = False + self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x)) + + @parametrize("make_fx_tracing_mode", ["fake", "symbolic"]) + def test_make_fx_tensor_queue_methods_fakify_internal_states( + self, make_fx_tracing_mode + ): + test = self + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 2) + self.check_tq_is_fake = True + self.current_test = test + + def forward(self, tq, x): + if self.check_tq_is_fake: + self.current_test.assertTrue(isinstance(tq, FakeScriptObject)) + x_cos = tq.pop() + tq.size() + x + x_sin = tq.pop() - tq.size() + x + return x_sin, x_cos, tq + + mod = Model() + tq = torch.classes._TorchScriptTesting._TensorQueue( + torch.empty( + 0, + ).fill_(-1) + ) + tq1 = torch.classes._TorchScriptTesting._TensorQueue( + torch.empty( + 0, + ).fill_(-1) + ) + for _ in range(2): + tq.push(torch.ones(2, 3)) + tq1.push(torch.ones(2, 3)) + x = torch.ones(2, 3) + prev_size = tq.size() + gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x) + self.assertEqual(self.tq_push_counter, 0) + self.assertEqual(self.tq_pop_counter, 2) + self.assertEqual(self.tq_size_counter, 2) + self.assertEqual(tq.size(), prev_size) + self.assertExpectedInline( + gm.code.strip("\n"), + """\ +def forward(self, arg0_1, arg1_1): + call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'pop') + call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'size') + add = torch.ops.aten.add.Tensor(call_torchbind, 1); call_torchbind = None + add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = None + call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop') + call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size') + sub = torch.ops.aten.sub.Tensor(call_torchbind_2, 0); call_torchbind_2 = None + add_2 = torch.ops.aten.add.Tensor(sub, arg1_1); sub = arg1_1 = None + return (add_2, add_1, arg0_1) + """, + ) + # turn off tq type checking in eager execution + mod.check_tq_is_fake = False + self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x)) + self.assertEqual(tq.size(), 0) + self.assertEqual(tq1.size(), 0) + + +@skipIfTorchDynamo("torchbind not supported with dynamo yet") +class TestRegisterFakeClass(TestCase): + def setUp(self): + load_torchbind_test_lib() + + def tearDown(self): + torch._library.fake_class_registry.global_fake_class_registry.clear() + + def test_register_fake_class_no_torch_bind_class(self): + with self.assertRaisesRegex(RuntimeError, "Tried to instantiate class"): + + @torch._library.register_fake_class("_TorchScriptTesting::NOT_A_VALID_NAME") + class Invalid: + pass + + def test_register_fake_class_no_from_real(self): + with self.assertRaisesRegex(RuntimeError, "define a classmethod from_real"): + + @torch._library.register_fake_class("_TorchScriptTesting::_Foo") + class InvalidFakeFoo: + def __init__(self): + pass + + def test_register_fake_class_from_real_not_classmethod(self): + with self.assertRaisesRegex(RuntimeError, "is not a classmethod"): + + @torch._library.register_fake_class("_TorchScriptTesting::_Foo") + class FakeFoo: + def __init__(self, x, y): + self.x = x + self.y = y + + def from_real(self, foo_obj): + x, y = foo_obj.__getstate__() + return FakeFoo(x, y) + + def test_register_fake_class_valid(self): + class FakeFoo: + def __init__(self, x, y): + self.x = x + self.y = y + + @classmethod + def from_real(cls, foo_obj): + x, y = foo_obj.__getstate__() + return cls(x, y) + + torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo) + + def test_register_fake_class_duplicate_registration(self): + @torch._library.register_fake_class("_TorchScriptTesting::_Foo") + class FakeFoo: + def __init__(self, x, y): + self.x = x + self.y = y + + @classmethod + def from_real(cls, foo_obj): + x, y = foo_obj.__getstate__() + return cls(x, y) + + with self.assertWarnsRegex(UserWarning, "already registered"): + torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo) + instantiate_parametrized_tests(TestExportTorchbind) diff --git a/torch/_higher_order_ops/torchbind.py b/torch/_higher_order_ops/torchbind.py index 25b99f20961..235dfe6ec41 100644 --- a/torch/_higher_order_ops/torchbind.py +++ b/torch/_higher_order_ops/torchbind.py @@ -1,15 +1,19 @@ +import logging from contextlib import contextmanager import torch from torch._C import DispatchKey # @manual from torch._functorch._aot_autograd.utils import KNOWN_TYPES from torch._higher_order_ops.utils import autograd_not_implemented +from torch._library.fake_class_registry import _ns_and_class_name, FakeScriptObject from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree from torch.fx.node import has_side_effect from torch.utils import _pytree as pytree +log = logging.getLogger(__name__) + # The call_torchbind operator represents a method invocation on a torchbind # object. The calling convention is: # call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs) @@ -52,7 +56,12 @@ def enable_torchbind_tracing(): @call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd) def call_torchbind_impl(obj, method, *args, **kwargs): - return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs) + if isinstance(obj, torch.ScriptObject): + return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs) + elif isinstance(obj, FakeScriptObject): + return getattr(obj.wrapped_obj, method)(*args, **kwargs) + else: + raise RuntimeError(f"Unsupported first arg type {type(obj)} for call_torchbind") @call_torchbind.py_impl(ProxyTorchDispatchMode) @@ -69,6 +78,21 @@ def inner(mode, *args, **kwargs): ) out = call_torchbind(*args, **kwargs) + obj, method, *rest_args = args + if isinstance(obj, torch.ScriptObject): + ns, class_name = _ns_and_class_name( + obj._type().qualified_name() # type: ignore[attr-defined] + ) + log.warning( + "Tracing torchbind method %s.%s with real ScriptObject. This may" + " cause the original object being mutated. If this is not intended," + ' You can register a fake class with torch._library.register_fake_class("%s::%s").', + class_name, + method, + ns, + class_name, + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) else: return call_torchbind(*args, **kwargs) diff --git a/torch/_library/__init__.py b/torch/_library/__init__.py index 8638a01bdf4..efb316c33b8 100644 --- a/torch/_library/__init__.py +++ b/torch/_library/__init__.py @@ -1,3 +1,5 @@ import torch._library.abstract_impl import torch._library.simple_registry import torch._library.utils + +from torch._library.fake_class_registry import register_fake_class diff --git a/torch/_library/abstract_impl.py b/torch/_library/abstract_impl.py index e09d3eace9b..f5eb7e3578c 100644 --- a/torch/_library/abstract_impl.py +++ b/torch/_library/abstract_impl.py @@ -61,13 +61,13 @@ class AbstractImplHolder: meta_kernel = construct_meta_kernel(self.qualname, self) self.lib.impl(self.qualname, meta_kernel, "Meta") - def deregister_abstract_impl(): + def deregister_fake_class(): if self.lib: self.lib._destroy() self.lib = None self.kernel = None - return RegistrationHandle(deregister_abstract_impl) + return RegistrationHandle(deregister_fake_class) def construct_meta_kernel( @@ -119,8 +119,9 @@ class AbstractImplCtx: Context object for writing abstract implementations for custom operators. """ - def __init__(self, _shape_env, _op): - self._shape_env = _shape_env + def __init__(self, _fake_mode, _op): + self._fake_mode = _fake_mode + self._shape_env = _fake_mode.shape_env self._op = _op def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt: @@ -204,3 +205,38 @@ class AbstractImplCtx: result, min=min, max=max ) return result + + def to_fake_tensor(self, tensor: torch.Tensor): + """ + Creates a fake tensor from a concrete tensor. Note: this is not needed for impl_abstract. + + This is useful for register_fake_class (which is necessary for torch.compile) for custom class. + Users need to implement a from_real method that takes a real custom object and creates a fake + custom object. Users can use this API to create fake tensors for the tensor states in the custom object. + + Args: + tensor (torch.Tensor): A concrete tensor. + + Example:: + >>> import torch + >>> @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") # xdoctest: +SKIP + ... class FakeTensorQueue: + ... def __init__(self, q): + ... self.queue = q + ... + ... @classmethod + ... def from_real(cls, real_tq): + ... ctx = torch.library.get_ctx() + ... fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()] + ... return cls(fake_queue) + ... + ... def push(self, x): + ... self.queue.append(x) + ... + ... def pop(self): + ... return self.queue.pop(0) + ... + ... def size(self): + ... return len(self.queue) + """ + return self._fake_mode.from_tensor(tensor) diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py new file mode 100644 index 00000000000..7eff7562841 --- /dev/null +++ b/torch/_library/fake_class_registry.py @@ -0,0 +1,234 @@ +import logging +import warnings +from typing import Any, Dict, Optional, Protocol, Tuple + +import torch + +from torch._library.utils import parse_namespace + +log = logging.getLogger(__name__) + + +class FakeScriptObject: + def __init__(self, wrapped_obj): + self.wrapped_obj = wrapped_obj + + +class HasStaticMethodFromReal(Protocol): + @classmethod + def from_real(cls, real_obj: torch.ScriptObject): + pass + + +class FakeClassRegistry: + def __init__(self): + self._registered_class: Dict[str, Any] = {} + + def has_impl(self, full_qualname: str) -> bool: + return full_qualname in self._registered_class + + def get_impl(self, full_qualname: str) -> Any: + self._check_registered(full_qualname) + return self._registered_class[full_qualname] + + def register(self, full_qualname: str, fake_class=None) -> None: + if self.has_impl(full_qualname): + warnings.warn( + f"{full_qualname} is already registered. Previous fake class is overrided with {fake_class}." + ) + self._registered_class[full_qualname] = fake_class + + def deregister(self, full_qualname: str) -> Any: + if not self.has_impl(full_qualname): + raise RuntimeError( + f"Cannot deregister {full_qualname}. Please use register_fake_class to register it first." + f" Or do you dereigster it twice?" + ) + self._check_registered(full_qualname) + return self._registered_class.pop(full_qualname) + + def clear(self) -> None: + self._registered_class.clear() + + def _check_registered(self, full_qualname: str) -> None: + if full_qualname not in self._registered_class: + raise RuntimeError( + f"{full_qualname} is not registered. Please use register_fake_class to register it first." + ) + + +global_fake_class_registry = FakeClassRegistry() + + +def to_fake_obj(fake_mode, x: torch.ScriptObject) -> FakeScriptObject: + fake_x = _fake_obj_from_real(fake_mode, x) + + def _call_torchbind(method_name): + from torch._higher_order_ops.torchbind import call_torchbind + + def wrapped(self_, *args, **kwargs): + return call_torchbind(self_, method_name, *args, **kwargs) + + return wrapped + + fake_x_wrapped = FakeScriptObject(fake_x) + for name in x._method_names(): # type: ignore[attr-defined] + attr = getattr(fake_x, name, None) + if attr: + if not callable(attr): + raise RuntimeError(f"Expect {name} to be a callable but got {attr}.") + setattr( + fake_x_wrapped, + name, + _call_torchbind(name).__get__(fake_x_wrapped), + ) + else: + log.warning("fake object of %s doesn't implement method %s.", x, name) + return fake_x_wrapped + + +def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] = None): + r"""Register a fake implementation for this class. + + It's in the same spirit of registering a fake implementation for + an operator but with the difference that it + associates a fake class with the original torch bind class (registered + with torch::class_). In this way, torch.compile can handle them properly + in components such as Dynamo and AOTAutograd. + + This API may be used as a decorator (see example). For the fake class, users + are required to provide a from_real classmethod that takes a real object and + returns an instance of the fake class. All tensors in the fake object should also + be properly fakified with to_fake_tensor() in from_real. + + Examples: + # For a custom class Foo defined in test_custom_class_registration.cpp: + TORCH_LIBRARY(_TorchScriptTesting, m) { + m.class_("_TensorQueue") + .def(torch::init()) + .def("push", &TensorQueue::push) + .def("pop", &TensorQueue::pop) + .def("top", &TensorQueue::top) + .def("size", &TensorQueue::size) + .def("clone_queue", &TensorQueue::clone_queue) + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr& self) + -> c10::Dict { + return self->serialize(); + }, + // __setstate__ + [](c10::Dict data) + -> c10::intrusive_ptr { + return c10::make_intrusive(std::move(data)); + }); + }; + # We could register a fake class FakeTensorQueue in Python as follows: + import torch + + @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") + class FakeTensorQueue: + def __init__(self, q): + self.queue = q + + @classmethod + def from_real(cls, real_tq): + ctx = torch.library.get_ctx() + fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.clone_queue()] + return cls(fake_queue) + + def push(self, x): + self.queue.append(x) + + def pop(self): + return self.queue.pop(0) + + def size(self): + return len(self.queue) + + """ + + def inner(fake_class: HasStaticMethodFromReal): + ns, name = parse_namespace(qualname) + + # This also checks whether the refered torch::class_ exists. + torchbind_class = torch._C._get_custom_class_python_wrapper(ns, name) + + from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None) + if not from_method: + raise RuntimeError(f"{fake_class} doesn't define a classmethod from_real.") + + if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod): + raise RuntimeError( + f"{_CONVERT_FROM_REAL_NAME} method is not a classmethod." + ) + + global_fake_class_registry.register(_full_qual_class_name(qualname), fake_class) + return fake_class + + if fake_class is None: + return inner + return inner(fake_class) + + +def deregister_fake_class(qualname): + return global_fake_class_registry.deregister(_full_qual_class_name(qualname)) + + +def has_fake_class(full_qualname) -> bool: + return global_fake_class_registry.has_impl(full_qualname) + + +def find_fake_class(full_qualname) -> Optional[Any]: + if not has_fake_class(full_qualname): + return None + return global_fake_class_registry.get_impl(full_qualname) + + +def _full_qual_class_name(qualname: str) -> str: + ns, name = parse_namespace(qualname) + return "__torch__.torch.classes." + ns + "." + name + + +# Return the namespace and class name from fully qualified name. +def _ns_and_class_name(full_qualname: str) -> Tuple[str, str]: + splits = full_qualname.split(".") + assert len(splits) == 5 + _torch, torch_ns, classes, ns, class_name = splits + return ns, class_name + + +def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any: + full_qualname = x._type().qualified_name() # type: ignore[attr-defined] + ns, class_name = _ns_and_class_name(full_qualname) + fake_class = find_fake_class(full_qualname) + if fake_class is None: + raise RuntimeError( + f" ScriptObject's {full_qualname} haven't registered a fake class." + f" Please use register_fake_class({ns}::{class_name}) to annotate a fake class for the script obj." + f" Specifically, create a python class that implements a fake version for all the methods" + f" that're used in the program and put annotated class in the program e.g. after loading the library." + f" The fake methods can be written in the same way as a meta kernel for an operator but need to additionally" + f" simulate the object's states. Be sure to add a {_CONVERT_FROM_REAL_NAME} classmethod" + f" to enable creating a fake obj from a real one." + ) + return fake_class + + +_CONVERT_FROM_REAL_NAME = "from_real" + + +def _fake_obj_from_real(fake_mode, x) -> Any: + fake_class = _find_fake_class_for_script_object(x) + + from_real_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None) + if not from_real_method: + raise RuntimeError( + f"{fake_class} must define a classmethod {_CONVERT_FROM_REAL_NAME}" + f" that converts the real object to the fake object." + ) + + # from_real defined by user need the ctx to fakify the tensor states. + ctx = torch._library.abstract_impl.AbstractImplCtx(fake_mode, None) + with torch._library.abstract_impl.set_ctx_getter(lambda: ctx): + return fake_class.from_real(x) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 8ec95747b19..c5a8e27d946 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1359,7 +1359,7 @@ class FakeTensorMode(TorchDispatchMode): func.name() ).abstract_impl.kernel if maybe_abstract_impl: - ctx = torch._library.abstract_impl.AbstractImplCtx(self.shape_env, func) + ctx = torch._library.abstract_impl.AbstractImplCtx(self, func) with torch._library.abstract_impl.set_ctx_getter(lambda: ctx), self: result = maybe_abstract_impl(*args, **kwargs) return result diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index a5ff0ae8812..e25fda855b8 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -26,6 +26,7 @@ import weakref import operator from torch.utils._stats import count import logging +from torch._library.fake_class_registry import FakeScriptObject from torch.overrides import TorchFunctionMode @@ -97,7 +98,7 @@ def set_proxy_slot(obj, tracer, proxy): # We DO want to clobber proxies whenever we run an inplace operation # on a tensor, and it affects the metadata on the proxy. tracer.tensor_tracker[obj] = proxy - elif isinstance(obj, torch.ScriptObject): + elif isinstance(obj, (torch.ScriptObject, FakeScriptObject)): # We DO want to clobber proxies, with a similar rationale as for tensors. tracer.script_object_tracker[obj] = proxy else: @@ -121,7 +122,7 @@ def has_proxy_slot(obj, tracer): def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x): if isinstance(obj, torch.Tensor): tracker = tracer.tensor_tracker - elif isinstance(obj, torch.ScriptObject): + elif isinstance(obj, (torch.ScriptObject, FakeScriptObject)): tracker = tracer.script_object_tracker else: assert isinstance(obj, py_sym_types), type(obj) @@ -141,7 +142,7 @@ def extract_val(val): return snapshot_fake(val) elif isinstance(val, py_sym_types): return val - elif isinstance(val, torch.ScriptObject): + elif isinstance(val, (torch.ScriptObject, FakeScriptObject)): return val elif isinstance(val, BackwardState): return val @@ -218,7 +219,7 @@ def track_tensor_tree(inner_res, proxy_res, *, constant, tracer): # NB: eagerly set meta here, so that the numbering is in order set_meta(proxy, e) set_proxy_slot(e, tracer, lambda: proxy) - elif isinstance(e, torch.ScriptObject): + elif isinstance(e, (torch.ScriptObject, FakeScriptObject)): set_proxy_slot(e, tracer, proxy) set_meta(proxy, e) elif isinstance(e, (tuple, list)): @@ -337,7 +338,7 @@ def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs): f_flat_args_kwargs = [ ( fetch_object_proxy(tracer)(x) - if isinstance(x, (torch.Tensor, torch.ScriptObject)) + if isinstance(x, (torch.Tensor, torch.ScriptObject, FakeScriptObject)) else x ) for x in flat_args_kwargs @@ -591,7 +592,7 @@ class PythonKeyTracer(Tracer): return get_proxy_slot(e, self, e, lambda e: e.proxy) elif isinstance(e, (torch.SymInt, torch.SymFloat, torch.SymBool)): return get_proxy_slot(e, self, e, lambda e: e()) - elif isinstance(e, torch.ScriptObject): + elif isinstance(e, (torch.ScriptObject, FakeScriptObject)): return get_proxy_slot(e, self, e) else: return e @@ -659,6 +660,11 @@ def wrap_key(f, tensors, tracer, pre_dispatch: bool): lambda t: get_proxy_slot(t, tracer, t, lambda x: x.proxy), out ) + out = pytree.tree_map_only( + (torch.ScriptObject, FakeScriptObject), + lambda t: get_proxy_slot(t, tracer, t, lambda x: x), + out + ) out = pytree.tree_map_only( (SymInt, SymFloat, SymBool), lambda t: get_proxy_slot(t, tracer)(), @@ -1182,7 +1188,10 @@ def make_fx(f, # NB: don't match on bools elif type(x) is int and tracing_mode == "symbolic": return shape_env.create_symintnode(shape_env.create_symbol(x, source, positive=None), hint=x, source=source) + elif isinstance(x, torch.ScriptObject): + return torch._library.fake_class_registry.to_fake_obj(fake_tensor_mode, x) + assert not isinstance(x, FakeScriptObject), f"ScriptObject {x} has been fakified. Cannot wrap_fake it again." return x sym_mode = proxy_mode.sym_mode diff --git a/torch/testing/_internal/torchbind_impls.py b/torch/testing/_internal/torchbind_impls.py index b90822ce447..933de441059 100644 --- a/torch/testing/_internal/torchbind_impls.py +++ b/torch/testing/_internal/torchbind_impls.py @@ -1,29 +1,32 @@ import torch -from torch.testing._internal.common_utils import find_library_location, IS_WINDOWS - -if IS_WINDOWS: - lib_file_path = find_library_location("torchbind_test.dll") - torch.ops.load_library(str(lib_file_path)) -else: - lib_file_path = find_library_location("libtorchbind_test.so") - torch.ops.load_library(str(lib_file_path)) -@torch.library.impl_abstract("_TorchScriptTesting::takes_foo_python_meta") -def fake_takes_foo(foo, z): - return foo.add_tensor(z) +def register_if_not(qualname): + entry = torch._library.simple_registry.singleton.find(qualname) + if entry.abstract_impl.kernel is not None: + return torch.library.impl_abstract(qualname) + else: + + def dummy_wrapper(fn): + return fn + + return dummy_wrapper -@torch.library.impl_abstract("_TorchScriptTesting::queue_pop") -def fake_queue_pop(tq): - return tq.pop() +# put these under a function because the corresponding library might not be loaded yet. +def register_fake_operators(): + @register_if_not("_TorchScriptTesting::takes_foo_python_meta") + def fake_takes_foo(foo, z): + return foo.add_tensor(z) + @register_if_not("_TorchScriptTesting::queue_pop") + def fake_queue_pop(tq): + return tq.pop() -@torch.library.impl_abstract("_TorchScriptTesting::queue_push") -def fake_queue_push(tq, x): - return tq.push(x) + @register_if_not("_TorchScriptTesting::queue_push") + def fake_queue_push(tq, x): + return tq.push(x) - -@torch.library.impl_abstract("_TorchScriptTesting::queue_size") -def fake_queue_size(tq, x): - return tq.size() + @register_if_not("_TorchScriptTesting::queue_size") + def fake_queue_size(tq, x): + return tq.size()