Add torch._library.register_fake_class to fakify torchBind class (#122622)

This PR only adds the abstract class registration logic without touching existing tests, so they still trace with real script objects. The added tests only cover the registration APIs and their error messages.

Our design is that the abstract implementation should be in Python; this is much better in terms of usability. But this also has implications for custom ops that take script objects as input, which are detailed later in this stack.
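For illustration, here is a minimal sketch of the intended usage, mirroring the FakeTensorQueue example registered in the tests and docstrings below: the fake class mirrors the methods of the real torchbind class, and its from_real classmethod fakifies tensor states via ctx.to_fake_tensor().

import torch

@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
class FakeTensorQueue:
    def __init__(self, q):
        self.queue = q

    @classmethod
    def from_real(cls, real_tq):
        # torch.library.get_ctx() provides to_fake_tensor() for converting
        # the real object's tensor states into fake tensors.
        ctx = torch.library.get_ctx()
        return cls([ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()])

    def push(self, x):
        self.queue.append(x)

    def pop(self):
        return self.queue.pop(0)

    def size(self):
        return len(self.queue)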

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122622
Approved by: https://github.com/zou3519
ghstack dependencies: #122619, #122620, #122621
ydwu4 2024-04-02 11:54:30 -07:00 committed by PyTorch MergeBot
parent 46c7235406
commit c77352b5cc
9 changed files with 616 additions and 34 deletions

View File: test/cpp/jit/test_custom_class_registrations.cpp

@@ -194,6 +194,10 @@ struct TensorQueue : torch::CustomClassHolder {
}
return ret;
}
std::vector<at::Tensor> get_raw_queue() {
std::vector<at::Tensor> raw_queue(queue_.begin(), queue_.end());
return raw_queue;
}
private:
std::deque<at::Tensor> queue_;
@@ -563,6 +567,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
.def("top", &TensorQueue::top)
.def("size", &TensorQueue::size)
.def("clone_queue", &TensorQueue::clone_queue)
.def("get_raw_queue", &TensorQueue::get_raw_queue)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<TensorQueue>& self)

View File: test/export/test_torchbind.py

@@ -1,21 +1,107 @@
# Owner(s): ["oncall: export"]
import unittest
import torch
import torch.testing._internal.torchbind_impls # noqa: F401
import torch.utils._pytree as pytree
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch._library.fake_class_registry import FakeScriptObject
from torch.export import export
from torch.export._trace import _export
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
find_library_location,
instantiate_parametrized_tests,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
parametrize,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.torchbind_impls import register_fake_operators
def load_torchbind_test_lib():
if IS_SANDCASTLE or IS_FBCODE:
torch.ops.load_library("//caffe2/test/cpp/jit:test_custom_class_registrations")
elif IS_MACOS:
raise unittest.SkipTest("non-portable load_library call used in test")
else:
lib_file_path = find_library_location("libtorchbind_test.so")
if IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
register_fake_operators()
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestExportTorchbind(TestCase):
def setUp(self):
load_torchbind_test_lib()
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class FakeFoo:
def __init__(self, x: int, y: int):
self.x = x
self.y = y
@classmethod
def from_real(cls, foo):
(x, y), _ = foo.__getstate__()
return cls(x, y)
def add_tensor(self, z):
return (self.x + self.y) * z
test = self
test.tq_push_counter = 0
test.tq_pop_counter = 0
test.tq_size_counter = 0
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
class FakeTensorQueue:
def __init__(self, q):
self.queue = q
@classmethod
def from_real(cls, real_tq):
ctx = torch.library.get_ctx()
fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
return cls(fake_queue)
def push(self, x):
test.tq_push_counter += 1
self.queue.append(x)
def pop(self):
test.tq_pop_counter += 1
return self.queue.pop(0)
def size(self):
test.tq_size_counter += 1
return len(self.queue)
def tearDown(self):
torch._library.fake_class_registry.deregister_fake_class(
"_TorchScriptTesting::_Foo"
)
torch._library.fake_class_registry.deregister_fake_class(
"_TorchScriptTesting::_TensorQueue"
)
def _assertEqualSkipScriptObject(self, exp, actual):
flat_exp = pytree.tree_leaves(exp)
flat_actual = pytree.tree_leaves(actual)
self.assertEqual(len(flat_exp), len(flat_actual))
for a, b in zip(flat_exp, flat_actual):
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
continue
self.assertEqual(a, b)
def _test_export_same_as_eager(
self, f, args, kwargs=None, strict=True, pre_dispatch=False
):
@@ -351,6 +437,189 @@ def forward(self, arg0_1, attr, arg1_1):
return (getitem_3, add_1)""", # noqa: B950
)
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
def test_make_fx_tensor_queue_methods(self, make_fx_tracing_mode):
test = self
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 2)
self.check_tq_is_fake = True
def forward(self, tq, x):
if self.check_tq_is_fake:
test.assertTrue(isinstance(tq, FakeScriptObject))
tq.push(x.cos())
tq.push(x.sin())
x_cos = tq.pop() + tq.size()
x_sin = tq.pop() - tq.size()
return x_sin, x_cos, tq
mod = Model()
tq = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
x = torch.ones(2, 3)
gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
self.assertEqual(self.tq_push_counter, 2)
self.assertEqual(self.tq_pop_counter, 2)
self.assertEqual(self.tq_size_counter, 2)
self.assertEqual(tq.size(), 0)
self.assertExpectedInline(
gm.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1):
cos = torch.ops.aten.cos.default(arg1_1)
call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'push', cos); cos = None
sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'push', sin); sin = None
call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
add = torch.ops.aten.add.Tensor(call_torchbind_2, 1); call_torchbind_2 = None
call_torchbind_4 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
call_torchbind_5 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
sub = torch.ops.aten.sub.Tensor(call_torchbind_4, 0); call_torchbind_4 = None
return (sub, add, arg0_1)
""",
)
mod.check_tq_is_fake = False
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
def test_make_fx_tensor_queue_methods_fakify_internal_states(
self, make_fx_tracing_mode
):
test = self
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 2)
self.check_tq_is_fake = True
self.current_test = test
def forward(self, tq, x):
if self.check_tq_is_fake:
self.current_test.assertTrue(isinstance(tq, FakeScriptObject))
x_cos = tq.pop() + tq.size() + x
x_sin = tq.pop() - tq.size() + x
return x_sin, x_cos, tq
mod = Model()
tq = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
for _ in range(2):
tq.push(torch.ones(2, 3))
tq1.push(torch.ones(2, 3))
x = torch.ones(2, 3)
prev_size = tq.size()
gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
self.assertEqual(self.tq_push_counter, 0)
self.assertEqual(self.tq_pop_counter, 2)
self.assertEqual(self.tq_size_counter, 2)
self.assertEqual(tq.size(), prev_size)
self.assertExpectedInline(
gm.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1):
call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
add = torch.ops.aten.add.Tensor(call_torchbind, 1); call_torchbind = None
add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = None
call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
sub = torch.ops.aten.sub.Tensor(call_torchbind_2, 0); call_torchbind_2 = None
add_2 = torch.ops.aten.add.Tensor(sub, arg1_1); sub = arg1_1 = None
return (add_2, add_1, arg0_1)
""",
)
# turn off tq type checking in eager execution
mod.check_tq_is_fake = False
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
self.assertEqual(tq.size(), 0)
self.assertEqual(tq1.size(), 0)
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestRegisterFakeClass(TestCase):
def setUp(self):
load_torchbind_test_lib()
def tearDown(self):
torch._library.fake_class_registry.global_fake_class_registry.clear()
def test_register_fake_class_no_torch_bind_class(self):
with self.assertRaisesRegex(RuntimeError, "Tried to instantiate class"):
@torch._library.register_fake_class("_TorchScriptTesting::NOT_A_VALID_NAME")
class Invalid:
pass
def test_register_fake_class_no_from_real(self):
with self.assertRaisesRegex(RuntimeError, "define a classmethod from_real"):
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class InvalidFakeFoo:
def __init__(self):
pass
def test_register_fake_class_from_real_not_classmethod(self):
with self.assertRaisesRegex(RuntimeError, "is not a classmethod"):
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class FakeFoo:
def __init__(self, x, y):
self.x = x
self.y = y
def from_real(self, foo_obj):
x, y = foo_obj.__getstate__()
return FakeFoo(x, y)
def test_register_fake_class_valid(self):
class FakeFoo:
def __init__(self, x, y):
self.x = x
self.y = y
@classmethod
def from_real(cls, foo_obj):
x, y = foo_obj.__getstate__()
return cls(x, y)
torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)
def test_register_fake_class_duplicate_registration(self):
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class FakeFoo:
def __init__(self, x, y):
self.x = x
self.y = y
@classmethod
def from_real(cls, foo_obj):
x, y = foo_obj.__getstate__()
return cls(x, y)
with self.assertWarnsRegex(UserWarning, "already registered"):
torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)
instantiate_parametrized_tests(TestExportTorchbind)

View File: torch/_higher_order_ops/torchbind.py

@@ -1,15 +1,19 @@
import logging
from contextlib import contextmanager
import torch
from torch._C import DispatchKey # @manual
from torch._functorch._aot_autograd.utils import KNOWN_TYPES
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._library.fake_class_registry import _ns_and_class_name, FakeScriptObject
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.fx.node import has_side_effect
from torch.utils import _pytree as pytree
log = logging.getLogger(__name__)
# The call_torchbind operator represents a method invocation on a torchbind
# object. The calling convention is:
# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
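# For example, tracing tq.push(x) on a _TensorQueue records a call of the form:
#   call_torchbind(tq, "push", x)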
@@ -52,7 +56,12 @@ def enable_torchbind_tracing():
@call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd)
def call_torchbind_impl(obj, method, *args, **kwargs):
- return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs)
+ if isinstance(obj, torch.ScriptObject):
+     return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs)
+ elif isinstance(obj, FakeScriptObject):
+     return getattr(obj.wrapped_obj, method)(*args, **kwargs)
+ else:
+     raise RuntimeError(f"Unsupported first arg type {type(obj)} for call_torchbind")
@call_torchbind.py_impl(ProxyTorchDispatchMode)
@@ -69,6 +78,21 @@ def inner(mode, *args, **kwargs):
)
out = call_torchbind(*args, **kwargs)
obj, method, *rest_args = args
if isinstance(obj, torch.ScriptObject):
ns, class_name = _ns_and_class_name(
obj._type().qualified_name() # type: ignore[attr-defined]
)
log.warning(
"Tracing torchbind method %s.%s with real ScriptObject. This may"
" cause the original object being mutated. If this is not intended,"
' You can register a fake class with torch._library.register_fake_class("%s::%s").',
class_name,
method,
ns,
class_name,
)
return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
else:
return call_torchbind(*args, **kwargs)

View File: torch/_library/__init__.py

@@ -1,3 +1,5 @@
import torch._library.abstract_impl
import torch._library.simple_registry
import torch._library.utils
from torch._library.fake_class_registry import register_fake_class

View File: torch/_library/abstract_impl.py

@@ -61,13 +61,13 @@ class AbstractImplHolder:
meta_kernel = construct_meta_kernel(self.qualname, self)
self.lib.impl(self.qualname, meta_kernel, "Meta")
- def deregister_abstract_impl():
+ def deregister_fake_class():
      if self.lib:
          self.lib._destroy()
          self.lib = None
      self.kernel = None
- return RegistrationHandle(deregister_abstract_impl)
+ return RegistrationHandle(deregister_fake_class)
def construct_meta_kernel(
@@ -119,8 +119,9 @@ class AbstractImplCtx:
Context object for writing abstract implementations for custom operators.
"""
- def __init__(self, _shape_env, _op):
-     self._shape_env = _shape_env
+ def __init__(self, _fake_mode, _op):
+     self._fake_mode = _fake_mode
+     self._shape_env = _fake_mode.shape_env
self._op = _op
def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt:
@@ -204,3 +205,38 @@ class AbstractImplCtx:
result, min=min, max=max
)
return result
def to_fake_tensor(self, tensor: torch.Tensor):
"""
Creates a fake tensor from a concrete tensor. Note: this is not needed for impl_abstract;
it is useful for register_fake_class (which is necessary for torch.compile to support custom classes).
Users need to implement a from_real classmethod that takes a real custom object and creates a fake
custom object, and can use this API to create fake tensors for the tensor states in the custom object.
Args:
tensor (torch.Tensor): A concrete tensor.
Example::
>>> import torch
>>> @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") # xdoctest: +SKIP
... class FakeTensorQueue:
... def __init__(self, q):
... self.queue = q
...
... @classmethod
... def from_real(cls, real_tq):
... ctx = torch.library.get_ctx()
... fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
... return cls(fake_queue)
...
... def push(self, x):
... self.queue.append(x)
...
... def pop(self):
... return self.queue.pop(0)
...
... def size(self):
... return len(self.queue)
"""
return self._fake_mode.from_tensor(tensor)

View File: torch/_library/fake_class_registry.py

@@ -0,0 +1,234 @@
import logging
import warnings
from typing import Any, Dict, Optional, Protocol, Tuple
import torch
from torch._library.utils import parse_namespace
log = logging.getLogger(__name__)
class FakeScriptObject:
def __init__(self, wrapped_obj):
self.wrapped_obj = wrapped_obj
class HasStaticMethodFromReal(Protocol):
@classmethod
def from_real(cls, real_obj: torch.ScriptObject):
pass
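# Global registry mapping a torchbind class's fully qualified name to its registered fake class.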
class FakeClassRegistry:
def __init__(self):
self._registered_class: Dict[str, Any] = {}
def has_impl(self, full_qualname: str) -> bool:
return full_qualname in self._registered_class
def get_impl(self, full_qualname: str) -> Any:
self._check_registered(full_qualname)
return self._registered_class[full_qualname]
def register(self, full_qualname: str, fake_class=None) -> None:
if self.has_impl(full_qualname):
warnings.warn(
f"{full_qualname} is already registered. Previous fake class is overrided with {fake_class}."
)
self._registered_class[full_qualname] = fake_class
def deregister(self, full_qualname: str) -> Any:
if not self.has_impl(full_qualname):
raise RuntimeError(
f"Cannot deregister {full_qualname}. Please use register_fake_class to register it first."
f" Or do you dereigster it twice?"
)
self._check_registered(full_qualname)
return self._registered_class.pop(full_qualname)
def clear(self) -> None:
self._registered_class.clear()
def _check_registered(self, full_qualname: str) -> None:
if full_qualname not in self._registered_class:
raise RuntimeError(
f"{full_qualname} is not registered. Please use register_fake_class to register it first."
)
global_fake_class_registry = FakeClassRegistry()
def to_fake_obj(fake_mode, x: torch.ScriptObject) -> FakeScriptObject:
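    # Wrap the user's fake object in a FakeScriptObject and rebind each of the
    # original ScriptObject's methods so that calls dispatch through call_torchbind.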
fake_x = _fake_obj_from_real(fake_mode, x)
def _call_torchbind(method_name):
from torch._higher_order_ops.torchbind import call_torchbind
def wrapped(self_, *args, **kwargs):
return call_torchbind(self_, method_name, *args, **kwargs)
return wrapped
fake_x_wrapped = FakeScriptObject(fake_x)
for name in x._method_names(): # type: ignore[attr-defined]
attr = getattr(fake_x, name, None)
if attr:
if not callable(attr):
raise RuntimeError(f"Expect {name} to be a callable but got {attr}.")
setattr(
fake_x_wrapped,
name,
_call_torchbind(name).__get__(fake_x_wrapped),
)
else:
log.warning("fake object of %s doesn't implement method %s.", x, name)
return fake_x_wrapped
def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] = None):
r"""Register a fake implementation for this class.
It's in the same spirit as registering a fake implementation for
an operator, with the difference that it
associates a fake class with the original torchbind class (registered
with torch::class_). In this way, torch.compile can handle these classes properly
in components such as Dynamo and AOTAutograd.
This API may be used as a decorator (see example). For the fake class, users
are required to provide a from_real classmethod that takes a real object and
returns an instance of the fake class. All tensors in the fake object should also
be properly fakified with to_fake_tensor() in from_real.
Examples:
# For a custom class TensorQueue defined in test_custom_class_registrations.cpp:
TORCH_LIBRARY(_TorchScriptTesting, m) {
m.class_<TensorQueue>("_TensorQueue")
.def(torch::init<at::Tensor>())
.def("push", &TensorQueue::push)
.def("pop", &TensorQueue::pop)
.def("top", &TensorQueue::top)
.def("size", &TensorQueue::size)
.def("clone_queue", &TensorQueue::clone_queue)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<TensorQueue>& self)
-> c10::Dict<std::string, at::Tensor> {
return self->serialize();
},
// __setstate__
[](c10::Dict<std::string, at::Tensor> data)
-> c10::intrusive_ptr<TensorQueue> {
return c10::make_intrusive<TensorQueue>(std::move(data));
});
};
# We could register a fake class FakeTensorQueue in Python as follows:
import torch
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
class FakeTensorQueue:
def __init__(self, q):
self.queue = q
@classmethod
def from_real(cls, real_tq):
ctx = torch.library.get_ctx()
fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.clone_queue()]
return cls(fake_queue)
def push(self, x):
self.queue.append(x)
def pop(self):
return self.queue.pop(0)
def size(self):
return len(self.queue)
"""
def inner(fake_class: HasStaticMethodFromReal):
ns, name = parse_namespace(qualname)
# This also checks whether the referenced torch::class_ exists.
torchbind_class = torch._C._get_custom_class_python_wrapper(ns, name)
from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
if not from_method:
raise RuntimeError(f"{fake_class} doesn't define a classmethod from_real.")
if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod):
raise RuntimeError(
f"{_CONVERT_FROM_REAL_NAME} method is not a classmethod."
)
global_fake_class_registry.register(_full_qual_class_name(qualname), fake_class)
return fake_class
if fake_class is None:
return inner
return inner(fake_class)
def deregister_fake_class(qualname):
return global_fake_class_registry.deregister(_full_qual_class_name(qualname))
def has_fake_class(full_qualname) -> bool:
return global_fake_class_registry.has_impl(full_qualname)
def find_fake_class(full_qualname) -> Optional[Any]:
if not has_fake_class(full_qualname):
return None
return global_fake_class_registry.get_impl(full_qualname)
def _full_qual_class_name(qualname: str) -> str:
ns, name = parse_namespace(qualname)
return "__torch__.torch.classes." + ns + "." + name
# Return the namespace and class name from a fully qualified name.
def _ns_and_class_name(full_qualname: str) -> Tuple[str, str]:
splits = full_qualname.split(".")
assert len(splits) == 5
_torch, torch_ns, classes, ns, class_name = splits
return ns, class_name
def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any:
full_qualname = x._type().qualified_name() # type: ignore[attr-defined]
ns, class_name = _ns_and_class_name(full_qualname)
fake_class = find_fake_class(full_qualname)
if fake_class is None:
raise RuntimeError(
f" ScriptObject's {full_qualname} haven't registered a fake class."
f" Please use register_fake_class({ns}::{class_name}) to annotate a fake class for the script obj."
f" Specifically, create a python class that implements a fake version for all the methods"
f" that're used in the program and put annotated class in the program e.g. after loading the library."
f" The fake methods can be written in the same way as a meta kernel for an operator but need to additionally"
f" simulate the object's states. Be sure to add a {_CONVERT_FROM_REAL_NAME} classmethod"
f" to enable creating a fake obj from a real one."
)
return fake_class
_CONVERT_FROM_REAL_NAME = "from_real"
def _fake_obj_from_real(fake_mode, x) -> Any:
fake_class = _find_fake_class_for_script_object(x)
from_real_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
if not from_real_method:
raise RuntimeError(
f"{fake_class} must define a classmethod {_CONVERT_FROM_REAL_NAME}"
f" that converts the real object to the fake object."
)
# The user-defined from_real needs the ctx to fakify the tensor states.
ctx = torch._library.abstract_impl.AbstractImplCtx(fake_mode, None)
with torch._library.abstract_impl.set_ctx_getter(lambda: ctx):
return fake_class.from_real(x)

View File: torch/_subclasses/fake_tensor.py

@@ -1359,7 +1359,7 @@ class FakeTensorMode(TorchDispatchMode):
func.name()
).abstract_impl.kernel
if maybe_abstract_impl:
- ctx = torch._library.abstract_impl.AbstractImplCtx(self.shape_env, func)
+ ctx = torch._library.abstract_impl.AbstractImplCtx(self, func)
with torch._library.abstract_impl.set_ctx_getter(lambda: ctx), self:
result = maybe_abstract_impl(*args, **kwargs)
return result

View File: torch/fx/experimental/proxy_tensor.py

@@ -26,6 +26,7 @@ import weakref
import operator
from torch.utils._stats import count
import logging
from torch._library.fake_class_registry import FakeScriptObject
from torch.overrides import TorchFunctionMode
@@ -97,7 +98,7 @@ def set_proxy_slot(obj, tracer, proxy):
# We DO want to clobber proxies whenever we run an inplace operation
# on a tensor, and it affects the metadata on the proxy.
tracer.tensor_tracker[obj] = proxy
- elif isinstance(obj, torch.ScriptObject):
+ elif isinstance(obj, (torch.ScriptObject, FakeScriptObject)):
# We DO want to clobber proxies, with a similar rationale as for tensors.
tracer.script_object_tracker[obj] = proxy
else:
@@ -121,7 +122,7 @@ def has_proxy_slot(obj, tracer):
def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x):
if isinstance(obj, torch.Tensor):
tracker = tracer.tensor_tracker
- elif isinstance(obj, torch.ScriptObject):
+ elif isinstance(obj, (torch.ScriptObject, FakeScriptObject)):
tracker = tracer.script_object_tracker
else:
assert isinstance(obj, py_sym_types), type(obj)
@@ -141,7 +142,7 @@ def extract_val(val):
return snapshot_fake(val)
elif isinstance(val, py_sym_types):
return val
- elif isinstance(val, torch.ScriptObject):
+ elif isinstance(val, (torch.ScriptObject, FakeScriptObject)):
return val
elif isinstance(val, BackwardState):
return val
@@ -218,7 +219,7 @@ def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
# NB: eagerly set meta here, so that the numbering is in order
set_meta(proxy, e)
set_proxy_slot(e, tracer, lambda: proxy)
- elif isinstance(e, torch.ScriptObject):
+ elif isinstance(e, (torch.ScriptObject, FakeScriptObject)):
set_proxy_slot(e, tracer, proxy)
set_meta(proxy, e)
elif isinstance(e, (tuple, list)):
@@ -337,7 +338,7 @@ def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs):
f_flat_args_kwargs = [
(
fetch_object_proxy(tracer)(x)
- if isinstance(x, (torch.Tensor, torch.ScriptObject))
+ if isinstance(x, (torch.Tensor, torch.ScriptObject, FakeScriptObject))
else x
)
for x in flat_args_kwargs
@@ -591,7 +592,7 @@ class PythonKeyTracer(Tracer):
return get_proxy_slot(e, self, e, lambda e: e.proxy)
elif isinstance(e, (torch.SymInt, torch.SymFloat, torch.SymBool)):
return get_proxy_slot(e, self, e, lambda e: e())
- elif isinstance(e, torch.ScriptObject):
+ elif isinstance(e, (torch.ScriptObject, FakeScriptObject)):
return get_proxy_slot(e, self, e)
else:
return e
@@ -659,6 +660,11 @@ def wrap_key(f, tensors, tracer, pre_dispatch: bool):
lambda t: get_proxy_slot(t, tracer, t, lambda x: x.proxy),
out
)
out = pytree.tree_map_only(
(torch.ScriptObject, FakeScriptObject),
lambda t: get_proxy_slot(t, tracer, t, lambda x: x),
out
)
out = pytree.tree_map_only(
(SymInt, SymFloat, SymBool),
lambda t: get_proxy_slot(t, tracer)(),
@@ -1182,7 +1188,10 @@ def make_fx(f,
# NB: don't match on bools
elif type(x) is int and tracing_mode == "symbolic":
return shape_env.create_symintnode(shape_env.create_symbol(x, source, positive=None), hint=x, source=source)
elif isinstance(x, torch.ScriptObject):
return torch._library.fake_class_registry.to_fake_obj(fake_tensor_mode, x)
assert not isinstance(x, FakeScriptObject), f"ScriptObject {x} has been fakified. Cannot wrap_fake it again."
return x
sym_mode = proxy_mode.sym_mode

View File: torch/testing/_internal/torchbind_impls.py

@@ -1,29 +1,32 @@
import torch
- from torch.testing._internal.common_utils import find_library_location, IS_WINDOWS
- if IS_WINDOWS:
-     lib_file_path = find_library_location("torchbind_test.dll")
-     torch.ops.load_library(str(lib_file_path))
- else:
-     lib_file_path = find_library_location("libtorchbind_test.so")
-     torch.ops.load_library(str(lib_file_path))
- @torch.library.impl_abstract("_TorchScriptTesting::takes_foo_python_meta")
- def fake_takes_foo(foo, z):
-     return foo.add_tensor(z)
+ def register_if_not(qualname):
+     entry = torch._library.simple_registry.singleton.find(qualname)
+     if entry.abstract_impl.kernel is None:
+         return torch.library.impl_abstract(qualname)
+     else:
+         def dummy_wrapper(fn):
+             return fn
+         return dummy_wrapper
- @torch.library.impl_abstract("_TorchScriptTesting::queue_pop")
- def fake_queue_pop(tq):
-     return tq.pop()
+ # put these under a function because the corresponding library might not be loaded yet.
+ def register_fake_operators():
+     @register_if_not("_TorchScriptTesting::takes_foo_python_meta")
+     def fake_takes_foo(foo, z):
+         return foo.add_tensor(z)
+     @register_if_not("_TorchScriptTesting::queue_pop")
+     def fake_queue_pop(tq):
+         return tq.pop()
- @torch.library.impl_abstract("_TorchScriptTesting::queue_push")
- def fake_queue_push(tq, x):
-     return tq.push(x)
+     @register_if_not("_TorchScriptTesting::queue_push")
+     def fake_queue_push(tq, x):
+         return tq.push(x)
- @torch.library.impl_abstract("_TorchScriptTesting::queue_size")
- def fake_queue_size(tq, x):
-     return tq.size()
+     @register_if_not("_TorchScriptTesting::queue_size")
+     def fake_queue_size(tq, x):
+         return tq.size()