mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
A big pain point people have with custom ops is that they do not accept arbitrary inputs/outputs. In this PR we introduce the concept of an "OpaqueObject", which allows users to pass arbitrary Python objects into custom operators.
Some parts of this implementation are still slightly annoying:
- The schema of the operator is `__torch__.torch.classes.aten.OpaqueObject` instead of the actual Python type.
- `@torch.library.custom_op` doesn't work (yet).
UX:
```python
from torch._library.opaque_object import make_opaque, get_payload

# your custom python class
class OpaqueQueue:
    def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
        super().__init__()
        self.queue = queue
        self.init_tensor_ = init_tensor_

    def push(self, tensor: torch.Tensor) -> None:
        self.queue.append(tensor)

    def pop(self) -> torch.Tensor:
        if len(self.queue) > 0:
            return self.queue.pop(0)
        return self.init_tensor_

    def size(self) -> int:
        return len(self.queue)

queue = OpaqueQueue([], torch.zeros(3))
obj: torch._C.ScriptObject = make_opaque(queue)

# obj.payload stores a direct reference to this python queue object
self.assertEqual(get_payload(obj), queue)

# This is able to be passed through the dispatcher
torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3))
self.assertEqual(queue.size(), 1)
```
Authoring a custom op:
```python
lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")

torch.library.define(
    "_TestOpaqueObject::queue_push",
    "(__torch__.torch.classes.aten.OpaqueObject a, Tensor b) -> ()",
    tags=torch.Tag.pt2_compliant_tag,
    lib=lib,
)

@torch.library.impl("_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=lib)
def push_impl(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
    # We can get the payload directly by get_payload(q)
    queue = get_payload(q)
    assert isinstance(queue, OpaqueQueue)
    queue.push(b)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162660
Approved by: https://github.com/zou3519
88 lines
2.7 KiB
Python
88 lines
2.7 KiB
Python
# Owner(s): ["module: custom-operators"]
|
|
|
|
import torch
|
|
from torch._dynamo.test_case import run_tests, TestCase
|
|
from torch._library.opaque_object import get_payload, make_opaque
|
|
|
|
|
|
class OpaqueQueue:
    """A FIFO queue of tensors that falls back to a fixed tensor when empty.

    This is an ordinary Python class; the tests below wrap instances with
    ``make_opaque`` so they can be handed to custom operators.
    """

    def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
        super().__init__()
        # The caller's list is stored as-is (not copied), so pushes and pops
        # are visible through the list object the caller passed in.
        self.queue = queue
        # Value handed back by pop() whenever nothing is queued.
        self.init_tensor_ = init_tensor_

    def push(self, tensor: torch.Tensor) -> None:
        """Append ``tensor`` to the back of the queue."""
        self.queue.append(tensor)

    def pop(self) -> torch.Tensor:
        """Remove and return the front tensor, or ``init_tensor_`` if empty."""
        if not self.queue:
            return self.init_tensor_
        return self.queue.pop(0)

    def size(self) -> int:
        """Return how many tensors are currently queued."""
        return len(self.queue)
|
|
|
|
|
|
class TestOpaqueObject(TestCase):
    """Tests for passing arbitrary Python objects through custom operators."""

    def setUp(self):
        """Register ``queue_push`` / ``queue_pop`` ops that take an OpaqueObject."""
        super().setUp()
        self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")  # noqa: TOR901
        # Destroy the library via a cleanup rather than in tearDown: unittest
        # runs cleanups even when setUp fails part-way (tearDown is skipped in
        # that case), so the FRAGMENT registrations never leak into later tests
        # and cause "already defined" failures there.
        self.addCleanup(self.lib._destroy)

        torch.library.define(
            "_TestOpaqueObject::queue_push",
            "(__torch__.torch.classes.aten.OpaqueObject a, Tensor b) -> ()",
            tags=torch.Tag.pt2_compliant_tag,
            lib=self.lib,
        )

        @torch.library.impl(
            "_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
        )
        def push_impl(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
            # Unwrap the ScriptObject to the Python object it carries.
            queue = get_payload(q)
            assert isinstance(queue, OpaqueQueue)
            queue.push(b)

        self.lib.define(
            "queue_pop(__torch__.torch.classes.aten.OpaqueObject a) -> Tensor",
        )

        def pop_impl(q: torch._C.ScriptObject) -> torch.Tensor:
            queue = get_payload(q)
            assert isinstance(queue, OpaqueQueue)
            return queue.pop()

        self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd")

    def test_creation(self):
        """make_opaque yields an OpaqueObject whose payload is the original object."""
        queue = OpaqueQueue([], torch.zeros(3))
        obj = make_opaque(queue)
        self.assertIsInstance(obj, torch._C.ScriptObject)
        self.assertEqual(str(obj._type()), "__torch__.torch.classes.aten.OpaqueObject")

        # obj.payload stores a direct reference to this python queue object,
        # so mutations through `queue` are visible through `payload`.
        payload = get_payload(obj)
        self.assertIs(payload, queue)
        queue.push(torch.ones(3))
        self.assertEqual(payload.size(), 1)

    def test_ops(self):
        """Custom ops can mutate and read the wrapped Python object."""
        queue = OpaqueQueue([], torch.zeros(3))
        obj = make_opaque(queue)

        torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3) + 1)
        self.assertEqual(queue.size(), 1)

        popped = torch.ops._TestOpaqueObject.queue_pop(obj)
        # Compare against an independently built tensor rather than the pushed one.
        self.assertEqual(popped, torch.ones(3) + 1)
        self.assertEqual(queue.size(), 0)
|
|
|
|
|
# Entry point when this file is executed directly as a script rather than
# collected by an external test runner.
if __name__ == "__main__":
    run_tests()
|