mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make DispatchKeySet serializable; add __eq__ (#152732)
These seem like reasonable things to add. Also fixes a bug in vLLM for me. Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/152732 Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
792736f9ac
commit
762844355e
|
|
@ -394,7 +394,7 @@ class DispatchKeySet final {
|
||||||
bool empty() const {
|
bool empty() const {
|
||||||
return repr_ == 0;
|
return repr_ == 0;
|
||||||
}
|
}
|
||||||
uint64_t raw_repr() {
|
uint64_t raw_repr() const {
|
||||||
return repr_;
|
return repr_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
# ruff: noqa: F841
|
# ruff: noqa: F841
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import pickle
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
@ -226,6 +227,20 @@ class TestPythonRegistration(TestCase):
|
||||||
torch.ops.custom.sum.default(a)
|
torch.ops.custom.sum.default(a)
|
||||||
self.assertTrue(meta_is_called)
|
self.assertTrue(meta_is_called)
|
||||||
|
|
||||||
|
def test_dispatchkeyset_pickle(self) -> None:
|
||||||
|
keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
|
||||||
|
serialized = pickle.dumps(keyset)
|
||||||
|
new_keyset = pickle.loads(serialized)
|
||||||
|
self.assertEqual(new_keyset, keyset)
|
||||||
|
|
||||||
|
def test_dispatchkeyset_eq(self) -> None:
|
||||||
|
a = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
|
||||||
|
b = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
|
||||||
|
c = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
|
||||||
|
self.assertTrue(a == b)
|
||||||
|
self.assertFalse(a != b)
|
||||||
|
self.assertTrue(a != c)
|
||||||
|
|
||||||
def test_override_aten_ops_with_multiple_libraries(self) -> None:
|
def test_override_aten_ops_with_multiple_libraries(self) -> None:
|
||||||
x = torch.tensor([1, 2])
|
x = torch.tensor([1, 2])
|
||||||
with _scoped_library("aten", "IMPL") as my_lib2:
|
with _scoped_library("aten", "IMPL") as my_lib2:
|
||||||
|
|
|
||||||
|
|
@ -772,6 +772,21 @@ void initDispatchBindings(PyObject* module) {
|
||||||
})
|
})
|
||||||
.def("has", &c10::DispatchKeySet::has)
|
.def("has", &c10::DispatchKeySet::has)
|
||||||
.def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); })
|
.def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); })
|
||||||
|
.def(
|
||||||
|
"__eq__",
|
||||||
|
[](c10::DispatchKeySet self, c10::DispatchKeySet other) {
|
||||||
|
return self.raw_repr() == other.raw_repr();
|
||||||
|
})
|
||||||
|
.def(py::pickle(
|
||||||
|
[](const c10::DispatchKeySet&
|
||||||
|
obj) { // __getstate__ : creates tuple of state
|
||||||
|
return py::make_tuple(obj.raw_repr());
|
||||||
|
},
|
||||||
|
[](const py::tuple& t) { // __setstate__ : restores state from tuple
|
||||||
|
TORCH_CHECK(
|
||||||
|
t.size() == 1, "__setstate__ expected tuple with one element");
|
||||||
|
return c10::DispatchKeySet::from_raw_repr(t[0].cast<uint64_t>());
|
||||||
|
}))
|
||||||
.def_static("from_raw_repr", &c10::DispatchKeySet::from_raw_repr);
|
.def_static("from_raw_repr", &c10::DispatchKeySet::from_raw_repr);
|
||||||
|
|
||||||
m.attr("_dispatch_autogradother_backends") =
|
m.attr("_dispatch_autogradother_backends") =
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user