mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67236 cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang Test Plan: Imported from OSS Reviewed By: rohan-varma Differential Revision: D31916706 Pulled By: mrshenli fbshipit-source-id: f3d2bcd938a384ec297f4094831c69d4059316bb
122 lines
3.7 KiB
C++
122 lines
3.7 KiB
C++
#pragma once
|
|
|
|
#include <c10d/ProcessGroup.hpp>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
|
|
namespace c10d {
|
|
|
|
// PyProcessGroup is a pybind11 trampoline class to allow a Python
|
|
// class to inherit from torch.distributed.ProcessGroup
|
|
class PyProcessGroup : public ProcessGroup {
|
|
public:
|
|
// PyWork is a pybind11 trampoline class to allow a Python
|
|
// class to inherit from torch.distributed.Work
|
|
class PyWork : public ProcessGroup::Work {
|
|
public:
|
|
PyWork() = default;
|
|
|
|
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
|
|
PYBIND11_OVERRIDE(
|
|
bool, /* Return type */
|
|
ProcessGroup::Work, /* Parent class */
|
|
wait, /* Name of function in C++ */
|
|
timeout);
|
|
}
|
|
};
|
|
|
|
using ProcessGroup::ProcessGroup;
|
|
|
|
const std::string getBackendName() const override {
|
|
PYBIND11_OVERRIDE_PURE(
|
|
std::string, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
getBackendName, /* Name of function in C++ */
|
|
);
|
|
}
|
|
|
|
c10::intrusive_ptr<ProcessGroup::Work> allgather(
|
|
std::vector<std::vector<at::Tensor>>& outputTensors,
|
|
std::vector<at::Tensor>& inputTensors,
|
|
const AllgatherOptions& opts = AllgatherOptions()) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
allgather, /* Name of function in C++ */
|
|
outputTensors,
|
|
inputTensors,
|
|
opts);
|
|
}
|
|
|
|
c10::intrusive_ptr<ProcessGroup::Work> allreduce(
|
|
std::vector<at::Tensor>& tensors,
|
|
const AllreduceOptions& opts = AllreduceOptions()) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
allreduce, /* Name of function in C++ */
|
|
tensors,
|
|
opts);
|
|
}
|
|
|
|
c10::intrusive_ptr<ProcessGroup::Work> barrier(
|
|
const BarrierOptions& opts = BarrierOptions()) {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
barrier, /* Name of function in C++ */
|
|
opts);
|
|
}
|
|
|
|
c10::intrusive_ptr<ProcessGroup::Work> broadcast(
|
|
std::vector<at::Tensor>& tensors,
|
|
const BroadcastOptions& opts = BroadcastOptions()) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
broadcast, /* Name of function in C++ */
|
|
tensors,
|
|
opts);
|
|
}
|
|
|
|
c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
|
|
std::vector<at::Tensor>& outputTensors,
|
|
std::vector<std::vector<at::Tensor>>& inputTensors,
|
|
const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
reduce_scatter, /* Name of function in C++ */
|
|
outputTensors,
|
|
inputTensors,
|
|
opts);
|
|
}
|
|
|
|
c10::intrusive_ptr<ProcessGroup::Work> send(
|
|
std::vector<at::Tensor>& tensors,
|
|
int dstRank,
|
|
int tag) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
send, /* Name of function in C++ */
|
|
tensors,
|
|
dstRank,
|
|
tag);
|
|
}
|
|
|
|
c10::intrusive_ptr<ProcessGroup::Work> recv(
|
|
std::vector<at::Tensor>& tensors,
|
|
int srcRank,
|
|
int tag) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
recv, /* Name of function in C++ */
|
|
tensors,
|
|
srcRank,
|
|
tag);
|
|
}
|
|
};
|
|
|
|
} // namespace c10d
|