// NOTE: include paths are reconstructed; Placement, Shard, StridedShard,
// Replicate, and Partial are assumed to be declared in the headers below.
#include <torch/csrc/distributed/python_placement.h>

#include <torch/csrc/distributed/placement.h>
#include <torch/csrc/utils/pybind.h>

#include <pybind11/stl.h> // std::optional <-> None conversions

using namespace pybind11::literals;

namespace torch::distributed {
namespace {
const auto placement_class_docstring =
    R"(The base class for the Placement type, which describes how a DTensor is placed onto the ``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together can describe the DTensor layout.

It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``, and ``Partial``.

This class is not meant to be used directly; it mainly serves as a typing stub.
)";
} // namespace

void initPlacementBindings(PyObject* module) {
  auto py_module = py::reinterpret_borrow<py::module>(module);
  auto distributed_module = py_module.def_submodule("_distributed");
  py::class_<Placement>(
      distributed_module, "Placement", placement_class_docstring)
      .def(py::init<>()) // Allow construction of Python subclasses.
      .def(
          "is_partial",
          &Placement::is_partial,
          py::arg("reduce_op") = py::none())
      .def("is_replicate", &Placement::is_replicate)
      .def("is_shard", &Placement::is_shard, py::arg("dim") = py::none());
  py::class_<Shard, Placement>(distributed_module, "Shard")
      .def(py::init<int64_t>(), py::arg("dim"))
      .def_readonly("dim", &Shard::dim)
      .def("is_shard", &Shard::is_shard, py::arg("dim") = py::none())
      .def(
          "__eq__",
          [](const Shard& lhs, const Shard& rhs) { return lhs == rhs; },
          py::is_operator())
      // Note: we need to use dicts for pickling to match the old
      // dataclasses.
      .def(py::pickle(
          [](const Shard& shard) { return py::dict("dim"_a = shard.dim); },
          [](const py::dict& d) {
            return Shard(py::cast<int64_t>(d["dim"]));
          }));
  py::class_<StridedShard, Shard>(distributed_module, "StridedShard")
      .def(
          py::init<int64_t, int64_t>(),
          py::arg("dim"),
          py::kw_only(),
          py::arg("split_factor"))
      .def_readonly("split_factor", &StridedShard::split_factor)
      .def("is_shard", &StridedShard::is_shard, py::arg("dim") = py::none())
      .def(
          "__eq__",
          // Compares against the Shard base so a StridedShard can be tested
          // for equality with a plain Shard as well.
          [](const StridedShard& lhs, const Shard& rhs) { return lhs == rhs; },
          py::is_operator())
      .def(py::pickle(
          [](const StridedShard& shard) {
            return py::dict(
                "dim"_a = shard.dim, "split_factor"_a = shard.split_factor);
          },
          [](const py::dict& d) {
            return StridedShard(
                py::cast<int64_t>(d["dim"]),
                py::cast<int64_t>(d["split_factor"]));
          }));
  py::class_<Replicate, Placement>(distributed_module, "Replicate")
      .def(py::init<>())
      .def("is_replicate", &Replicate::is_replicate)
      .def(
          "__eq__",
          [](const Replicate& lhs, const Replicate& rhs) {
            return lhs == rhs;
          },
          py::is_operator())
      .def(py::pickle(
          // I observed SIGSEGV when trying to use None as the pickled state,
          // though AFAICT that matches the behavior of
          // object().__reduce__().
          // test_placement_types.test_type_identification will repro if an
          // enterprising reader wants to get this fixed.
          [](const Replicate& repl) { return py::dict(); },
          [](const py::dict&) { return Replicate(); }));
  py::class_<Partial, Placement>(distributed_module, "Partial")
      .def(py::init<>())
      .def(py::init<std::optional<std::string>>(), py::arg("reduce_op"))
      .def_readonly("reduce_op", &Partial::reduce_op)
      .def(
          "is_partial",
          &Partial::is_partial,
          py::arg("reduce_op") = py::none())
      .def(
          "__eq__",
          [](const Partial& lhs, const Partial& rhs) { return lhs == rhs; },
          py::is_operator())
      .def(py::pickle(
          [](const Partial& part) {
            return py::dict("reduce_op"_a = part.reduce_op);
          },
          [](const py::dict& d) {
            return Partial(py::cast<std::string>(d["reduce_op"]));
          }));
}
} // namespace torch::distributed
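
// ---------------------------------------------------------------------------
// Usage sketch. `initPlacementBindings` takes a raw PyObject* for an
// already-created module and must run under the GIL, which holds inside a
// pybind11 module-init body. Below is a minimal sketch, assuming a
// hypothetical standalone extension named `placement_demo` rather than
// torch's real initialization path:
//
//   PYBIND11_MODULE(placement_demo, m) {
//     // m.ptr() is the underlying PyObject*; the binding function hangs a
//     // `_distributed` submodule with Placement/Shard/... off of it.
//     torch::distributed::initPlacementBindings(m.ptr());
//   }
//
// From Python, the types then round-trip through pickle via the dict-based
// __getstate__/__setstate__ pairs defined above:
//
//   import pickle
//   import placement_demo
//   s = placement_demo._distributed.Shard(dim=0)
//   assert pickle.loads(pickle.dumps(s)) == s
//   assert placement_demo._distributed.Replicate().is_replicate()
// ---------------------------------------------------------------------------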