diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index ee132a100cc..49dafe1e3cb 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -394,7 +394,7 @@ class DispatchKeySet final { bool empty() const { return repr_ == 0; } - uint64_t raw_repr() { + uint64_t raw_repr() const { return repr_; } diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 2e6bbd406e4..9349612575d 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -2,6 +2,7 @@ # ruff: noqa: F841 import logging +import pickle import sys import tempfile import unittest @@ -226,6 +227,20 @@ class TestPythonRegistration(TestCase): torch.ops.custom.sum.default(a) self.assertTrue(meta_is_called) + def test_dispatchkeyset_pickle(self) -> None: + keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU) + serialized = pickle.dumps(keyset) + new_keyset = pickle.loads(serialized) + self.assertEqual(new_keyset, keyset) + + def test_dispatchkeyset_eq(self) -> None: + a = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU) + b = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU) + c = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU) + self.assertTrue(a == b) + self.assertFalse(a != b) + self.assertTrue(a != c) + def test_override_aten_ops_with_multiple_libraries(self) -> None: x = torch.tensor([1, 2]) with _scoped_library("aten", "IMPL") as my_lib2: diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 10e6b821008..bf304e981ab 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -772,6 +772,21 @@ void initDispatchBindings(PyObject* module) { }) .def("has", &c10::DispatchKeySet::has) .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); }) + .def( + "__eq__", + [](c10::DispatchKeySet self, c10::DispatchKeySet other) { + return self.raw_repr() == other.raw_repr(); + }) + .def(py::pickle( + [](const c10::DispatchKeySet& + obj) { // __getstate__ : creates tuple of state + return py::make_tuple(obj.raw_repr()); + }, + [](const py::tuple& t) { // __setstate__ : restores state from tuple + TORCH_CHECK( + t.size() == 1, "__setstate__ expected tuple with one element"); + return c10::DispatchKeySet::from_raw_repr(t[0].cast()); + })) .def_static("from_raw_repr", &c10::DispatchKeySet::from_raw_repr); m.attr("_dispatch_autogradother_backends") =