mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make DispatchKeySet serializable; add __eq__ (#152732)
These seem like reasonable things to add. Also fixes a bug in vLLM for me. Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/152732 Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
792736f9ac
commit
762844355e
|
|
@ -394,7 +394,7 @@ class DispatchKeySet final {
|
|||
bool empty() const {
|
||||
return repr_ == 0;
|
||||
}
|
||||
uint64_t raw_repr() {
|
||||
uint64_t raw_repr() const {
|
||||
return repr_;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
# ruff: noqa: F841
|
||||
|
||||
import logging
|
||||
import pickle
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
|
@ -226,6 +227,20 @@ class TestPythonRegistration(TestCase):
|
|||
torch.ops.custom.sum.default(a)
|
||||
self.assertTrue(meta_is_called)
|
||||
|
||||
def test_dispatchkeyset_pickle(self) -> None:
|
||||
keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
|
||||
serialized = pickle.dumps(keyset)
|
||||
new_keyset = pickle.loads(serialized)
|
||||
self.assertEqual(new_keyset, keyset)
|
||||
|
||||
def test_dispatchkeyset_eq(self) -> None:
|
||||
a = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
|
||||
b = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
|
||||
c = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
|
||||
self.assertTrue(a == b)
|
||||
self.assertFalse(a != b)
|
||||
self.assertTrue(a != c)
|
||||
|
||||
def test_override_aten_ops_with_multiple_libraries(self) -> None:
|
||||
x = torch.tensor([1, 2])
|
||||
with _scoped_library("aten", "IMPL") as my_lib2:
|
||||
|
|
|
|||
|
|
@ -772,6 +772,21 @@ void initDispatchBindings(PyObject* module) {
|
|||
})
|
||||
.def("has", &c10::DispatchKeySet::has)
|
||||
.def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); })
|
||||
.def(
|
||||
"__eq__",
|
||||
[](c10::DispatchKeySet self, c10::DispatchKeySet other) {
|
||||
return self.raw_repr() == other.raw_repr();
|
||||
})
|
||||
.def(py::pickle(
|
||||
[](const c10::DispatchKeySet&
|
||||
obj) { // __getstate__ : creates tuple of state
|
||||
return py::make_tuple(obj.raw_repr());
|
||||
},
|
||||
[](const py::tuple& t) { // __setstate__ : restores state from tuple
|
||||
TORCH_CHECK(
|
||||
t.size() == 1, "__setstate__ expected tuple with one element");
|
||||
return c10::DispatchKeySet::from_raw_repr(t[0].cast<uint64_t>());
|
||||
}))
|
||||
.def_static("from_raw_repr", &c10::DispatchKeySet::from_raw_repr);
|
||||
|
||||
m.attr("_dispatch_autogradother_backends") =
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user