Make DispatchKeySet serializable; add __eq__ (#152732)

These seem like reasonable things to add. Also fixes a bug in vLLM for
me.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152732
Approved by: https://github.com/bdhirsh
This commit is contained in:
rzou 2025-05-02 14:00:25 -07:00 committed by PyTorch MergeBot
parent 792736f9ac
commit 762844355e
3 changed files with 31 additions and 1 deletions

View File

@ -394,7 +394,7 @@ class DispatchKeySet final {
bool empty() const {
return repr_ == 0;
}
uint64_t raw_repr() {
uint64_t raw_repr() const {
return repr_;
}

View File

@ -2,6 +2,7 @@
# ruff: noqa: F841
import logging
import pickle
import sys
import tempfile
import unittest
@ -226,6 +227,20 @@ class TestPythonRegistration(TestCase):
torch.ops.custom.sum.default(a)
self.assertTrue(meta_is_called)
def test_dispatchkeyset_pickle(self) -> None:
keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
serialized = pickle.dumps(keyset)
new_keyset = pickle.loads(serialized)
self.assertEqual(new_keyset, keyset)
def test_dispatchkeyset_eq(self) -> None:
a = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
b = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
c = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
self.assertTrue(a == b)
self.assertFalse(a != b)
self.assertTrue(a != c)
def test_override_aten_ops_with_multiple_libraries(self) -> None:
x = torch.tensor([1, 2])
with _scoped_library("aten", "IMPL") as my_lib2:

View File

@ -772,6 +772,21 @@ void initDispatchBindings(PyObject* module) {
})
.def("has", &c10::DispatchKeySet::has)
.def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); })
.def(
"__eq__",
[](c10::DispatchKeySet self, c10::DispatchKeySet other) {
return self.raw_repr() == other.raw_repr();
})
.def(py::pickle(
[](const c10::DispatchKeySet&
obj) { // __getstate__ : creates tuple of state
return py::make_tuple(obj.raw_repr());
},
[](const py::tuple& t) { // __setstate__ : restores state from tuple
TORCH_CHECK(
t.size() == 1, "__setstate__ expected tuple with one element");
return c10::DispatchKeySet::from_raw_repr(t[0].cast<uint64_t>());
}))
.def_static("from_raw_repr", &c10::DispatchKeySet::from_raw_repr);
m.attr("_dispatch_autogradother_backends") =