pytorch/torch/csrc/distributed/c10d/PyProcessGroup.hpp
Shen Li 58fefa6516 Add pybind trampoline for ProcessGroup and Work (#66338)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66338

This commit exposes the c10d extension API to Python. Users can
now override c10d communication behaviors in pure Python and no
longer need to go through the C++ extension steps.

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D31514351

Pulled By: mrshenli

fbshipit-source-id: a8b94af0af7960c078e1006c29b25f7f3bd86c81
2021-10-11 06:41:06 -07:00

113 lines
3.4 KiB
C++

#pragma once
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/utils/pybind.h>
namespace c10d {
// PyProcessGroup is the pybind11 trampoline for ProcessGroup: it allows a
// Python class to inherit from torch.distributed.ProcessGroup and have its
// Python-defined methods picked up when the matching C++ virtual is called.
//
// Each override below consists of a single pybind11 override macro whose
// arguments are, in order:
//   (return type, parent class, name of the C++ function, call arguments...)
// Per the pybind11 documentation, PYBIND11_OVERRIDE calls the Python
// override when the subclass defines one and otherwise falls back to the
// parent-class implementation; PYBIND11_OVERRIDE_PURE has no fallback and
// fails at call time when the Python override is missing.
class PyProcessGroup : public ProcessGroup {
 public:
  // PyWork is the pybind11 trampoline for ProcessGroup::Work: it allows a
  // Python class to inherit from torch.distributed.Work.
  class PyWork : public ProcessGroup::Work {
   public:
    PyWork() = default;

    // Forwards wait(timeout) to the Python override if present, otherwise to
    // ProcessGroup::Work::wait.
    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
      PYBIND11_OVERRIDE(bool, ProcessGroup::Work, wait, timeout);
    }
  };

  // Re-export the base-class constructors so the trampoline can be
  // constructed with the same arguments as ProcessGroup.
  using ProcessGroup::ProcessGroup;

  // ---- Identification ----------------------------------------------------

  // Uses the *_PURE macro, so there is no C++ fallback for this method.
  // The trailing comma after the function name is intentional: it stands in
  // for the empty argument list of this zero-argument function, which is the
  // form shown in the pybind11 documentation.
  const std::string getBackendName() const override {
    PYBIND11_OVERRIDE_PURE(std::string, ProcessGroup, getBackendName, );
  }

  // ---- Collectives -------------------------------------------------------

  // Forwards broadcast(tensors, opts) to Python, else to ProcessGroup.
  c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>,
        ProcessGroup,
        broadcast,
        tensors,
        opts);
  }

  // Forwards allreduce(tensors, opts) to Python, else to ProcessGroup.
  c10::intrusive_ptr<ProcessGroup::Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>,
        ProcessGroup,
        allreduce,
        tensors,
        opts);
  }

  // Forwards allgather(outputTensors, inputTensors, opts) to Python, else to
  // ProcessGroup.
  c10::intrusive_ptr<ProcessGroup::Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>,
        ProcessGroup,
        allgather,
        outputTensors,
        inputTensors,
        opts);
  }

  // Forwards reduce_scatter(outputTensors, inputTensors, opts) to Python,
  // else to ProcessGroup.
  c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>,
        ProcessGroup,
        reduce_scatter,
        outputTensors,
        inputTensors,
        opts);
  }

  // ---- Point-to-point ----------------------------------------------------

  // Forwards send(tensors, dstRank, tag) to Python, else to ProcessGroup.
  c10::intrusive_ptr<ProcessGroup::Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>,
        ProcessGroup,
        send,
        tensors,
        dstRank,
        tag);
  }

  // Forwards recv(tensors, srcRank, tag) to Python, else to ProcessGroup.
  c10::intrusive_ptr<ProcessGroup::Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<ProcessGroup::Work>,
        ProcessGroup,
        recv,
        tensors,
        srcRank,
        tag);
  }
};
} // namespace c10d