mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add torch._library.register_fake_class to fakify torchBind class (#122622)
This PR only adds abstract class registration logic without touching existing tests so they still trace with real script object. The added tests are only for registration APIs and test error messages. Our design is that the abstract implementation should be in Python. This is much better in terms of usability. But this also has implications for custom op that takes script object as input, which is detailed later in this stack. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122622 Approved by: https://github.com/zou3519 ghstack dependencies: #122619, #122620, #122621
This commit is contained in:
parent
46c7235406
commit
c77352b5cc
|
|
@ -194,6 +194,10 @@ struct TensorQueue : torch::CustomClassHolder {
|
|||
}
|
||||
return ret;
|
||||
}
|
||||
std::vector<at::Tensor> get_raw_queue() {
|
||||
std::vector<at::Tensor> raw_queue(queue_.begin(), queue_.end());
|
||||
return raw_queue;
|
||||
}
|
||||
|
||||
private:
|
||||
std::deque<at::Tensor> queue_;
|
||||
|
|
@ -563,6 +567,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
|
|||
.def("top", &TensorQueue::top)
|
||||
.def("size", &TensorQueue::size)
|
||||
.def("clone_queue", &TensorQueue::clone_queue)
|
||||
.def("get_raw_queue", &TensorQueue::get_raw_queue)
|
||||
.def_pickle(
|
||||
// __getstate__
|
||||
[](const c10::intrusive_ptr<TensorQueue>& self)
|
||||
|
|
|
|||
|
|
@ -1,21 +1,107 @@
|
|||
# Owner(s): ["oncall: export"]
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.testing._internal.torchbind_impls # noqa: F401
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch.export import export
|
||||
from torch.export._trace import _export
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import (
|
||||
find_library_location,
|
||||
instantiate_parametrized_tests,
|
||||
IS_FBCODE,
|
||||
IS_MACOS,
|
||||
IS_SANDCASTLE,
|
||||
IS_WINDOWS,
|
||||
parametrize,
|
||||
run_tests,
|
||||
skipIfTorchDynamo,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.torchbind_impls import register_fake_operators
|
||||
|
||||
|
||||
def load_torchbind_test_lib():
|
||||
if IS_SANDCASTLE or IS_FBCODE:
|
||||
torch.ops.load_library("//caffe2/test/cpp/jit:test_custom_class_registrations")
|
||||
elif IS_MACOS:
|
||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
||||
else:
|
||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
||||
if IS_WINDOWS:
|
||||
lib_file_path = find_library_location("torchbind_test.dll")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
|
||||
register_fake_operators()
|
||||
|
||||
|
||||
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
|
||||
class TestExportTorchbind(TestCase):
|
||||
def setUp(self):
|
||||
load_torchbind_test_lib()
|
||||
|
||||
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
|
||||
class FakeFoo:
|
||||
def __init__(self, x: int, y: int):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
@classmethod
|
||||
def from_real(cls, foo):
|
||||
(x, y), _ = foo.__getstate__()
|
||||
return cls(x, y)
|
||||
|
||||
def add_tensor(self, z):
|
||||
return (self.x + self.y) * z
|
||||
|
||||
test = self
|
||||
test.tq_push_counter = 0
|
||||
test.tq_pop_counter = 0
|
||||
test.tq_size_counter = 0
|
||||
|
||||
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
|
||||
class FakeTensorQueue:
|
||||
def __init__(self, q):
|
||||
self.queue = q
|
||||
|
||||
@classmethod
|
||||
def from_real(cls, real_tq):
|
||||
ctx = torch.library.get_ctx()
|
||||
fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
|
||||
return cls(fake_queue)
|
||||
|
||||
def push(self, x):
|
||||
test.tq_push_counter += 1
|
||||
self.queue.append(x)
|
||||
|
||||
def pop(self):
|
||||
test.tq_pop_counter += 1
|
||||
return self.queue.pop(0)
|
||||
|
||||
def size(self):
|
||||
test.tq_size_counter += 1
|
||||
return len(self.queue)
|
||||
|
||||
def tearDown(self):
|
||||
torch._library.fake_class_registry.deregister_fake_class(
|
||||
"_TorchScriptTesting::_Foo"
|
||||
)
|
||||
torch._library.fake_class_registry.deregister_fake_class(
|
||||
"_TorchScriptTesting::_TensorQueue"
|
||||
)
|
||||
|
||||
def _assertEqualSkipScriptObject(self, exp, actual):
|
||||
flat_exp = pytree.tree_leaves(exp)
|
||||
flat_actual = pytree.tree_leaves(actual)
|
||||
self.assertEqual(len(flat_exp), len(flat_actual))
|
||||
for a, b in zip(flat_exp, flat_actual):
|
||||
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
|
||||
continue
|
||||
self.assertEqual(a, b)
|
||||
|
||||
def _test_export_same_as_eager(
|
||||
self, f, args, kwargs=None, strict=True, pre_dispatch=False
|
||||
):
|
||||
|
|
@ -351,6 +437,189 @@ def forward(self, arg0_1, attr, arg1_1):
|
|||
return (getitem_3, add_1)""", # noqa: B950
|
||||
)
|
||||
|
||||
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
|
||||
def test_make_fx_tensor_queue_methods(self, make_fx_tracing_mode):
|
||||
test = self
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(3, 2)
|
||||
self.check_tq_is_fake = True
|
||||
|
||||
def forward(self, tq, x):
|
||||
if self.check_tq_is_fake:
|
||||
test.assertTrue(isinstance(tq, FakeScriptObject))
|
||||
tq.push(x.cos())
|
||||
tq.push(x.sin())
|
||||
x_cos = tq.pop() + tq.size()
|
||||
x_sin = tq.pop() - tq.size()
|
||||
return x_sin, x_cos, tq
|
||||
|
||||
mod = Model()
|
||||
tq = torch.classes._TorchScriptTesting._TensorQueue(
|
||||
torch.empty(
|
||||
0,
|
||||
).fill_(-1)
|
||||
)
|
||||
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
|
||||
torch.empty(
|
||||
0,
|
||||
).fill_(-1)
|
||||
)
|
||||
x = torch.ones(2, 3)
|
||||
gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
|
||||
self.assertEqual(self.tq_push_counter, 2)
|
||||
self.assertEqual(self.tq_pop_counter, 2)
|
||||
self.assertEqual(self.tq_size_counter, 2)
|
||||
self.assertEqual(tq.size(), 0)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip("\n"),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
cos = torch.ops.aten.cos.default(arg1_1)
|
||||
call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'push', cos); cos = None
|
||||
sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
|
||||
call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'push', sin); sin = None
|
||||
call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
|
||||
call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
|
||||
add = torch.ops.aten.add.Tensor(call_torchbind_2, 1); call_torchbind_2 = None
|
||||
call_torchbind_4 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
|
||||
call_torchbind_5 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
|
||||
sub = torch.ops.aten.sub.Tensor(call_torchbind_4, 0); call_torchbind_4 = None
|
||||
return (sub, add, arg0_1)
|
||||
""",
|
||||
)
|
||||
mod.check_tq_is_fake = False
|
||||
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
|
||||
|
||||
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
|
||||
def test_make_fx_tensor_queue_methods_fakify_internal_states(
|
||||
self, make_fx_tracing_mode
|
||||
):
|
||||
test = self
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(3, 2)
|
||||
self.check_tq_is_fake = True
|
||||
self.current_test = test
|
||||
|
||||
def forward(self, tq, x):
|
||||
if self.check_tq_is_fake:
|
||||
self.current_test.assertTrue(isinstance(tq, FakeScriptObject))
|
||||
x_cos = tq.pop() + tq.size() + x
|
||||
x_sin = tq.pop() - tq.size() + x
|
||||
return x_sin, x_cos, tq
|
||||
|
||||
mod = Model()
|
||||
tq = torch.classes._TorchScriptTesting._TensorQueue(
|
||||
torch.empty(
|
||||
0,
|
||||
).fill_(-1)
|
||||
)
|
||||
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
|
||||
torch.empty(
|
||||
0,
|
||||
).fill_(-1)
|
||||
)
|
||||
for _ in range(2):
|
||||
tq.push(torch.ones(2, 3))
|
||||
tq1.push(torch.ones(2, 3))
|
||||
x = torch.ones(2, 3)
|
||||
prev_size = tq.size()
|
||||
gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
|
||||
self.assertEqual(self.tq_push_counter, 0)
|
||||
self.assertEqual(self.tq_pop_counter, 2)
|
||||
self.assertEqual(self.tq_size_counter, 2)
|
||||
self.assertEqual(tq.size(), prev_size)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip("\n"),
|
||||
"""\
|
||||
def forward(self, arg0_1, arg1_1):
|
||||
call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
|
||||
call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
|
||||
add = torch.ops.aten.add.Tensor(call_torchbind, 1); call_torchbind = None
|
||||
add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = None
|
||||
call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
|
||||
call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
|
||||
sub = torch.ops.aten.sub.Tensor(call_torchbind_2, 0); call_torchbind_2 = None
|
||||
add_2 = torch.ops.aten.add.Tensor(sub, arg1_1); sub = arg1_1 = None
|
||||
return (add_2, add_1, arg0_1)
|
||||
""",
|
||||
)
|
||||
# turn off tq type checking in eager execution
|
||||
mod.check_tq_is_fake = False
|
||||
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
|
||||
self.assertEqual(tq.size(), 0)
|
||||
self.assertEqual(tq1.size(), 0)
|
||||
|
||||
|
||||
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
|
||||
class TestRegisterFakeClass(TestCase):
|
||||
def setUp(self):
|
||||
load_torchbind_test_lib()
|
||||
|
||||
def tearDown(self):
|
||||
torch._library.fake_class_registry.global_fake_class_registry.clear()
|
||||
|
||||
def test_register_fake_class_no_torch_bind_class(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "Tried to instantiate class"):
|
||||
|
||||
@torch._library.register_fake_class("_TorchScriptTesting::NOT_A_VALID_NAME")
|
||||
class Invalid:
|
||||
pass
|
||||
|
||||
def test_register_fake_class_no_from_real(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "define a classmethod from_real"):
|
||||
|
||||
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
|
||||
class InvalidFakeFoo:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def test_register_fake_class_from_real_not_classmethod(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "is not a classmethod"):
|
||||
|
||||
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
|
||||
class FakeFoo:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def from_real(self, foo_obj):
|
||||
x, y = foo_obj.__getstate__()
|
||||
return FakeFoo(x, y)
|
||||
|
||||
def test_register_fake_class_valid(self):
|
||||
class FakeFoo:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
@classmethod
|
||||
def from_real(cls, foo_obj):
|
||||
x, y = foo_obj.__getstate__()
|
||||
return cls(x, y)
|
||||
|
||||
torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)
|
||||
|
||||
def test_register_fake_class_duplicate_registration(self):
|
||||
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
|
||||
class FakeFoo:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
@classmethod
|
||||
def from_real(cls, foo_obj):
|
||||
x, y = foo_obj.__getstate__()
|
||||
return cls(x, y)
|
||||
|
||||
with self.assertWarnsRegex(UserWarning, "already registered"):
|
||||
torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestExportTorchbind)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,19 @@
|
|||
import logging
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
from torch._C import DispatchKey # @manual
|
||||
from torch._functorch._aot_autograd.utils import KNOWN_TYPES
|
||||
from torch._higher_order_ops.utils import autograd_not_implemented
|
||||
from torch._library.fake_class_registry import _ns_and_class_name, FakeScriptObject
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
|
||||
from torch.fx.node import has_side_effect
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# The call_torchbind operator represents a method invocation on a torchbind
|
||||
# object. The calling convention is:
|
||||
# call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
|
||||
|
|
@ -52,7 +56,12 @@ def enable_torchbind_tracing():
|
|||
|
||||
@call_torchbind.py_impl(DispatchKey.CompositeExplicitAutograd)
|
||||
def call_torchbind_impl(obj, method, *args, **kwargs):
|
||||
return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs)
|
||||
if isinstance(obj, torch.ScriptObject):
|
||||
return _orig_scriptmethod_call(getattr(obj, method), *args, **kwargs)
|
||||
elif isinstance(obj, FakeScriptObject):
|
||||
return getattr(obj.wrapped_obj, method)(*args, **kwargs)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported first arg type {type(obj)} for call_torchbind")
|
||||
|
||||
|
||||
@call_torchbind.py_impl(ProxyTorchDispatchMode)
|
||||
|
|
@ -69,6 +78,21 @@ def inner(mode, *args, **kwargs):
|
|||
)
|
||||
out = call_torchbind(*args, **kwargs)
|
||||
|
||||
obj, method, *rest_args = args
|
||||
if isinstance(obj, torch.ScriptObject):
|
||||
ns, class_name = _ns_and_class_name(
|
||||
obj._type().qualified_name() # type: ignore[attr-defined]
|
||||
)
|
||||
log.warning(
|
||||
"Tracing torchbind method %s.%s with real ScriptObject. This may"
|
||||
" cause the original object being mutated. If this is not intended,"
|
||||
' You can register a fake class with torch._library.register_fake_class("%s::%s").',
|
||||
class_name,
|
||||
method,
|
||||
ns,
|
||||
class_name,
|
||||
)
|
||||
|
||||
return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
|
||||
else:
|
||||
return call_torchbind(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import torch._library.abstract_impl
|
||||
import torch._library.simple_registry
|
||||
import torch._library.utils
|
||||
|
||||
from torch._library.fake_class_registry import register_fake_class
|
||||
|
|
|
|||
|
|
@ -61,13 +61,13 @@ class AbstractImplHolder:
|
|||
meta_kernel = construct_meta_kernel(self.qualname, self)
|
||||
self.lib.impl(self.qualname, meta_kernel, "Meta")
|
||||
|
||||
def deregister_abstract_impl():
|
||||
def deregister_fake_class():
|
||||
if self.lib:
|
||||
self.lib._destroy()
|
||||
self.lib = None
|
||||
self.kernel = None
|
||||
|
||||
return RegistrationHandle(deregister_abstract_impl)
|
||||
return RegistrationHandle(deregister_fake_class)
|
||||
|
||||
|
||||
def construct_meta_kernel(
|
||||
|
|
@ -119,8 +119,9 @@ class AbstractImplCtx:
|
|||
Context object for writing abstract implementations for custom operators.
|
||||
"""
|
||||
|
||||
def __init__(self, _shape_env, _op):
|
||||
self._shape_env = _shape_env
|
||||
def __init__(self, _fake_mode, _op):
|
||||
self._fake_mode = _fake_mode
|
||||
self._shape_env = _fake_mode.shape_env
|
||||
self._op = _op
|
||||
|
||||
def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt:
|
||||
|
|
@ -204,3 +205,38 @@ class AbstractImplCtx:
|
|||
result, min=min, max=max
|
||||
)
|
||||
return result
|
||||
|
||||
def to_fake_tensor(self, tensor: torch.Tensor):
|
||||
"""
|
||||
Creates a fake tensor from a concrete tensor. Note: this is not needed for impl_abstract.
|
||||
|
||||
This is useful for register_fake_class (which is necessary for torch.compile) for custom class.
|
||||
Users need to implement a from_real method that takes a real custom object and creates a fake
|
||||
custom object. Users can use this API to create fake tensors for the tensor states in the custom object.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): A concrete tensor.
|
||||
|
||||
Example::
|
||||
>>> import torch
|
||||
>>> @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") # xdoctest: +SKIP
|
||||
... class FakeTensorQueue:
|
||||
... def __init__(self, q):
|
||||
... self.queue = q
|
||||
...
|
||||
... @classmethod
|
||||
... def from_real(cls, real_tq):
|
||||
... ctx = torch.library.get_ctx()
|
||||
... fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
|
||||
... return cls(fake_queue)
|
||||
...
|
||||
... def push(self, x):
|
||||
... self.queue.append(x)
|
||||
...
|
||||
... def pop(self):
|
||||
... return self.queue.pop(0)
|
||||
...
|
||||
... def size(self):
|
||||
... return len(self.queue)
|
||||
"""
|
||||
return self._fake_mode.from_tensor(tensor)
|
||||
|
|
|
|||
234
torch/_library/fake_class_registry.py
Normal file
234
torch/_library/fake_class_registry.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
import logging
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional, Protocol, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from torch._library.utils import parse_namespace
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FakeScriptObject:
|
||||
def __init__(self, wrapped_obj):
|
||||
self.wrapped_obj = wrapped_obj
|
||||
|
||||
|
||||
class HasStaticMethodFromReal(Protocol):
|
||||
@classmethod
|
||||
def from_real(cls, real_obj: torch.ScriptObject):
|
||||
pass
|
||||
|
||||
|
||||
class FakeClassRegistry:
|
||||
def __init__(self):
|
||||
self._registered_class: Dict[str, Any] = {}
|
||||
|
||||
def has_impl(self, full_qualname: str) -> bool:
|
||||
return full_qualname in self._registered_class
|
||||
|
||||
def get_impl(self, full_qualname: str) -> Any:
|
||||
self._check_registered(full_qualname)
|
||||
return self._registered_class[full_qualname]
|
||||
|
||||
def register(self, full_qualname: str, fake_class=None) -> None:
|
||||
if self.has_impl(full_qualname):
|
||||
warnings.warn(
|
||||
f"{full_qualname} is already registered. Previous fake class is overrided with {fake_class}."
|
||||
)
|
||||
self._registered_class[full_qualname] = fake_class
|
||||
|
||||
def deregister(self, full_qualname: str) -> Any:
|
||||
if not self.has_impl(full_qualname):
|
||||
raise RuntimeError(
|
||||
f"Cannot deregister {full_qualname}. Please use register_fake_class to register it first."
|
||||
f" Or do you dereigster it twice?"
|
||||
)
|
||||
self._check_registered(full_qualname)
|
||||
return self._registered_class.pop(full_qualname)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._registered_class.clear()
|
||||
|
||||
def _check_registered(self, full_qualname: str) -> None:
|
||||
if full_qualname not in self._registered_class:
|
||||
raise RuntimeError(
|
||||
f"{full_qualname} is not registered. Please use register_fake_class to register it first."
|
||||
)
|
||||
|
||||
|
||||
global_fake_class_registry = FakeClassRegistry()
|
||||
|
||||
|
||||
def to_fake_obj(fake_mode, x: torch.ScriptObject) -> FakeScriptObject:
|
||||
fake_x = _fake_obj_from_real(fake_mode, x)
|
||||
|
||||
def _call_torchbind(method_name):
|
||||
from torch._higher_order_ops.torchbind import call_torchbind
|
||||
|
||||
def wrapped(self_, *args, **kwargs):
|
||||
return call_torchbind(self_, method_name, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
fake_x_wrapped = FakeScriptObject(fake_x)
|
||||
for name in x._method_names(): # type: ignore[attr-defined]
|
||||
attr = getattr(fake_x, name, None)
|
||||
if attr:
|
||||
if not callable(attr):
|
||||
raise RuntimeError(f"Expect {name} to be a callable but got {attr}.")
|
||||
setattr(
|
||||
fake_x_wrapped,
|
||||
name,
|
||||
_call_torchbind(name).__get__(fake_x_wrapped),
|
||||
)
|
||||
else:
|
||||
log.warning("fake object of %s doesn't implement method %s.", x, name)
|
||||
return fake_x_wrapped
|
||||
|
||||
|
||||
def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] = None):
|
||||
r"""Register a fake implementation for this class.
|
||||
|
||||
It's in the same spirit of registering a fake implementation for
|
||||
an operator but with the difference that it
|
||||
associates a fake class with the original torch bind class (registered
|
||||
with torch::class_). In this way, torch.compile can handle them properly
|
||||
in components such as Dynamo and AOTAutograd.
|
||||
|
||||
This API may be used as a decorator (see example). For the fake class, users
|
||||
are required to provide a from_real classmethod that takes a real object and
|
||||
returns an instance of the fake class. All tensors in the fake object should also
|
||||
be properly fakified with to_fake_tensor() in from_real.
|
||||
|
||||
Examples:
|
||||
# For a custom class Foo defined in test_custom_class_registration.cpp:
|
||||
TORCH_LIBRARY(_TorchScriptTesting, m) {
|
||||
m.class_<TensorQueue>("_TensorQueue")
|
||||
.def(torch::init<at::Tensor>())
|
||||
.def("push", &TensorQueue::push)
|
||||
.def("pop", &TensorQueue::pop)
|
||||
.def("top", &TensorQueue::top)
|
||||
.def("size", &TensorQueue::size)
|
||||
.def("clone_queue", &TensorQueue::clone_queue)
|
||||
.def_pickle(
|
||||
// __getstate__
|
||||
[](const c10::intrusive_ptr<TensorQueue>& self)
|
||||
-> c10::Dict<std::string, at::Tensor> {
|
||||
return self->serialize();
|
||||
},
|
||||
// __setstate__
|
||||
[](c10::Dict<std::string, at::Tensor> data)
|
||||
-> c10::intrusive_ptr<TensorQueue> {
|
||||
return c10::make_intrusive<TensorQueue>(std::move(data));
|
||||
});
|
||||
};
|
||||
# We could register a fake class FakeTensorQueue in Python as follows:
|
||||
import torch
|
||||
|
||||
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
|
||||
class FakeTensorQueue:
|
||||
def __init__(self, q):
|
||||
self.queue = q
|
||||
|
||||
@classmethod
|
||||
def from_real(cls, real_tq):
|
||||
ctx = torch.library.get_ctx()
|
||||
fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.clone_queue()]
|
||||
return cls(fake_queue)
|
||||
|
||||
def push(self, x):
|
||||
self.queue.append(x)
|
||||
|
||||
def pop(self):
|
||||
return self.queue.pop(0)
|
||||
|
||||
def size(self):
|
||||
return len(self.queue)
|
||||
|
||||
"""
|
||||
|
||||
def inner(fake_class: HasStaticMethodFromReal):
|
||||
ns, name = parse_namespace(qualname)
|
||||
|
||||
# This also checks whether the refered torch::class_ exists.
|
||||
torchbind_class = torch._C._get_custom_class_python_wrapper(ns, name)
|
||||
|
||||
from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
|
||||
if not from_method:
|
||||
raise RuntimeError(f"{fake_class} doesn't define a classmethod from_real.")
|
||||
|
||||
if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod):
|
||||
raise RuntimeError(
|
||||
f"{_CONVERT_FROM_REAL_NAME} method is not a classmethod."
|
||||
)
|
||||
|
||||
global_fake_class_registry.register(_full_qual_class_name(qualname), fake_class)
|
||||
return fake_class
|
||||
|
||||
if fake_class is None:
|
||||
return inner
|
||||
return inner(fake_class)
|
||||
|
||||
|
||||
def deregister_fake_class(qualname):
|
||||
return global_fake_class_registry.deregister(_full_qual_class_name(qualname))
|
||||
|
||||
|
||||
def has_fake_class(full_qualname) -> bool:
|
||||
return global_fake_class_registry.has_impl(full_qualname)
|
||||
|
||||
|
||||
def find_fake_class(full_qualname) -> Optional[Any]:
|
||||
if not has_fake_class(full_qualname):
|
||||
return None
|
||||
return global_fake_class_registry.get_impl(full_qualname)
|
||||
|
||||
|
||||
def _full_qual_class_name(qualname: str) -> str:
|
||||
ns, name = parse_namespace(qualname)
|
||||
return "__torch__.torch.classes." + ns + "." + name
|
||||
|
||||
|
||||
# Return the namespace and class name from fully qualified name.
|
||||
def _ns_and_class_name(full_qualname: str) -> Tuple[str, str]:
|
||||
splits = full_qualname.split(".")
|
||||
assert len(splits) == 5
|
||||
_torch, torch_ns, classes, ns, class_name = splits
|
||||
return ns, class_name
|
||||
|
||||
|
||||
def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any:
|
||||
full_qualname = x._type().qualified_name() # type: ignore[attr-defined]
|
||||
ns, class_name = _ns_and_class_name(full_qualname)
|
||||
fake_class = find_fake_class(full_qualname)
|
||||
if fake_class is None:
|
||||
raise RuntimeError(
|
||||
f" ScriptObject's {full_qualname} haven't registered a fake class."
|
||||
f" Please use register_fake_class({ns}::{class_name}) to annotate a fake class for the script obj."
|
||||
f" Specifically, create a python class that implements a fake version for all the methods"
|
||||
f" that're used in the program and put annotated class in the program e.g. after loading the library."
|
||||
f" The fake methods can be written in the same way as a meta kernel for an operator but need to additionally"
|
||||
f" simulate the object's states. Be sure to add a {_CONVERT_FROM_REAL_NAME} classmethod"
|
||||
f" to enable creating a fake obj from a real one."
|
||||
)
|
||||
return fake_class
|
||||
|
||||
|
||||
_CONVERT_FROM_REAL_NAME = "from_real"
|
||||
|
||||
|
||||
def _fake_obj_from_real(fake_mode, x) -> Any:
|
||||
fake_class = _find_fake_class_for_script_object(x)
|
||||
|
||||
from_real_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
|
||||
if not from_real_method:
|
||||
raise RuntimeError(
|
||||
f"{fake_class} must define a classmethod {_CONVERT_FROM_REAL_NAME}"
|
||||
f" that converts the real object to the fake object."
|
||||
)
|
||||
|
||||
# from_real defined by user need the ctx to fakify the tensor states.
|
||||
ctx = torch._library.abstract_impl.AbstractImplCtx(fake_mode, None)
|
||||
with torch._library.abstract_impl.set_ctx_getter(lambda: ctx):
|
||||
return fake_class.from_real(x)
|
||||
|
|
@ -1359,7 +1359,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
func.name()
|
||||
).abstract_impl.kernel
|
||||
if maybe_abstract_impl:
|
||||
ctx = torch._library.abstract_impl.AbstractImplCtx(self.shape_env, func)
|
||||
ctx = torch._library.abstract_impl.AbstractImplCtx(self, func)
|
||||
with torch._library.abstract_impl.set_ctx_getter(lambda: ctx), self:
|
||||
result = maybe_abstract_impl(*args, **kwargs)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import weakref
|
|||
import operator
|
||||
from torch.utils._stats import count
|
||||
import logging
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
|
||||
from torch.overrides import TorchFunctionMode
|
||||
|
||||
|
|
@ -97,7 +98,7 @@ def set_proxy_slot(obj, tracer, proxy):
|
|||
# We DO want to clobber proxies whenever we run an inplace operation
|
||||
# on a tensor, and it affects the metadata on the proxy.
|
||||
tracer.tensor_tracker[obj] = proxy
|
||||
elif isinstance(obj, torch.ScriptObject):
|
||||
elif isinstance(obj, (torch.ScriptObject, FakeScriptObject)):
|
||||
# We DO want to clobber proxies, with a similar rationale as for tensors.
|
||||
tracer.script_object_tracker[obj] = proxy
|
||||
else:
|
||||
|
|
@ -121,7 +122,7 @@ def has_proxy_slot(obj, tracer):
|
|||
def get_proxy_slot(obj, tracer, default=no_default, transform=lambda x: x):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tracker = tracer.tensor_tracker
|
||||
elif isinstance(obj, torch.ScriptObject):
|
||||
elif isinstance(obj, (torch.ScriptObject, FakeScriptObject)):
|
||||
tracker = tracer.script_object_tracker
|
||||
else:
|
||||
assert isinstance(obj, py_sym_types), type(obj)
|
||||
|
|
@ -141,7 +142,7 @@ def extract_val(val):
|
|||
return snapshot_fake(val)
|
||||
elif isinstance(val, py_sym_types):
|
||||
return val
|
||||
elif isinstance(val, torch.ScriptObject):
|
||||
elif isinstance(val, (torch.ScriptObject, FakeScriptObject)):
|
||||
return val
|
||||
elif isinstance(val, BackwardState):
|
||||
return val
|
||||
|
|
@ -218,7 +219,7 @@ def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
|
|||
# NB: eagerly set meta here, so that the numbering is in order
|
||||
set_meta(proxy, e)
|
||||
set_proxy_slot(e, tracer, lambda: proxy)
|
||||
elif isinstance(e, torch.ScriptObject):
|
||||
elif isinstance(e, (torch.ScriptObject, FakeScriptObject)):
|
||||
set_proxy_slot(e, tracer, proxy)
|
||||
set_meta(proxy, e)
|
||||
elif isinstance(e, (tuple, list)):
|
||||
|
|
@ -337,7 +338,7 @@ def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs):
|
|||
f_flat_args_kwargs = [
|
||||
(
|
||||
fetch_object_proxy(tracer)(x)
|
||||
if isinstance(x, (torch.Tensor, torch.ScriptObject))
|
||||
if isinstance(x, (torch.Tensor, torch.ScriptObject, FakeScriptObject))
|
||||
else x
|
||||
)
|
||||
for x in flat_args_kwargs
|
||||
|
|
@ -591,7 +592,7 @@ class PythonKeyTracer(Tracer):
|
|||
return get_proxy_slot(e, self, e, lambda e: e.proxy)
|
||||
elif isinstance(e, (torch.SymInt, torch.SymFloat, torch.SymBool)):
|
||||
return get_proxy_slot(e, self, e, lambda e: e())
|
||||
elif isinstance(e, torch.ScriptObject):
|
||||
elif isinstance(e, (torch.ScriptObject, FakeScriptObject)):
|
||||
return get_proxy_slot(e, self, e)
|
||||
else:
|
||||
return e
|
||||
|
|
@ -659,6 +660,11 @@ def wrap_key(f, tensors, tracer, pre_dispatch: bool):
|
|||
lambda t: get_proxy_slot(t, tracer, t, lambda x: x.proxy),
|
||||
out
|
||||
)
|
||||
out = pytree.tree_map_only(
|
||||
(torch.ScriptObject, FakeScriptObject),
|
||||
lambda t: get_proxy_slot(t, tracer, t, lambda x: x),
|
||||
out
|
||||
)
|
||||
out = pytree.tree_map_only(
|
||||
(SymInt, SymFloat, SymBool),
|
||||
lambda t: get_proxy_slot(t, tracer)(),
|
||||
|
|
@ -1182,7 +1188,10 @@ def make_fx(f,
|
|||
# NB: don't match on bools
|
||||
elif type(x) is int and tracing_mode == "symbolic":
|
||||
return shape_env.create_symintnode(shape_env.create_symbol(x, source, positive=None), hint=x, source=source)
|
||||
elif isinstance(x, torch.ScriptObject):
|
||||
return torch._library.fake_class_registry.to_fake_obj(fake_tensor_mode, x)
|
||||
|
||||
assert not isinstance(x, FakeScriptObject), f"ScriptObject {x} has been fakified. Cannot wrap_fake it again."
|
||||
return x
|
||||
|
||||
sym_mode = proxy_mode.sym_mode
|
||||
|
|
|
|||
|
|
@ -1,29 +1,32 @@
|
|||
import torch
|
||||
from torch.testing._internal.common_utils import find_library_location, IS_WINDOWS
|
||||
|
||||
if IS_WINDOWS:
|
||||
lib_file_path = find_library_location("torchbind_test.dll")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
else:
|
||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
|
||||
|
||||
@torch.library.impl_abstract("_TorchScriptTesting::takes_foo_python_meta")
|
||||
def fake_takes_foo(foo, z):
|
||||
return foo.add_tensor(z)
|
||||
def register_if_not(qualname):
|
||||
entry = torch._library.simple_registry.singleton.find(qualname)
|
||||
if entry.abstract_impl.kernel is not None:
|
||||
return torch.library.impl_abstract(qualname)
|
||||
else:
|
||||
|
||||
def dummy_wrapper(fn):
|
||||
return fn
|
||||
|
||||
return dummy_wrapper
|
||||
|
||||
|
||||
@torch.library.impl_abstract("_TorchScriptTesting::queue_pop")
|
||||
def fake_queue_pop(tq):
|
||||
return tq.pop()
|
||||
# put these under a function because the corresponding library might not be loaded yet.
|
||||
def register_fake_operators():
|
||||
@register_if_not("_TorchScriptTesting::takes_foo_python_meta")
|
||||
def fake_takes_foo(foo, z):
|
||||
return foo.add_tensor(z)
|
||||
|
||||
@register_if_not("_TorchScriptTesting::queue_pop")
|
||||
def fake_queue_pop(tq):
|
||||
return tq.pop()
|
||||
|
||||
@torch.library.impl_abstract("_TorchScriptTesting::queue_push")
|
||||
def fake_queue_push(tq, x):
|
||||
return tq.push(x)
|
||||
@register_if_not("_TorchScriptTesting::queue_push")
|
||||
def fake_queue_push(tq, x):
|
||||
return tq.push(x)
|
||||
|
||||
|
||||
@torch.library.impl_abstract("_TorchScriptTesting::queue_size")
|
||||
def fake_queue_size(tq, x):
|
||||
return tq.size()
|
||||
@register_if_not("_TorchScriptTesting::queue_size")
|
||||
def fake_queue_size(tq, x):
|
||||
return tq.size()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user