mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This allows infra/trainers to get detailed stats about communication efficiency without knowing anything about what model or distributed training paradigms have been used. This is helpful as an infra/trainer package usually prefers to be as model/algorithm agnostic as possible. Therefore, we cannot assume that infra/trainers have access to all collectives used by the model authors. This commit adds an `OnCompletion` hook to `ProcessGroupNCCL` which will be fired on every work completion event. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106988 Approved by: https://github.com/kumpera, https://github.com/H-Huang ghstack dependencies: #107140, #107141, #107160
182 lines
5.3 KiB
C++
182 lines
5.3 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
|
#include <torch/csrc/jit/python/pybind_utils.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
|
|
namespace c10d {
|
|
|
|
// PyProcessGroup is a pybind11 trampoline class to allow a Python
|
|
// class to inherit from torch.distributed.ProcessGroup
|
|
class PyProcessGroup : public ProcessGroup {
|
|
public:
|
|
// PyWork is a pybind11 trampoline class to allow a Python
|
|
// class to inherit from torch.distributed.Work
|
|
class PyWork : public Work {
|
|
public:
|
|
PyWork() = default;
|
|
|
|
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
|
|
PYBIND11_OVERRIDE(
|
|
bool, /* Return type */
|
|
Work, /* Parent class */
|
|
wait, /* Name of function in C++ */
|
|
timeout);
|
|
}
|
|
|
|
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
|
|
// We cannot use PYBIND11_OVERRIDE because:
|
|
// 1. We have to >MANUALLY< unwrap the PyFutureWrapper and
|
|
// 2. The python name is get_future
|
|
pybind11::gil_scoped_acquire gil;
|
|
auto override =
|
|
pybind11::get_override(static_cast<const Work*>(this), "get_future");
|
|
|
|
if (override) {
|
|
py::object o = override();
|
|
auto futWrapper =
|
|
o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>();
|
|
return futWrapper->fut;
|
|
}
|
|
|
|
return Work::getFuture();
|
|
}
|
|
};
|
|
|
|
using ProcessGroup::ProcessGroup;
|
|
|
|
const std::string getBackendName() const override {
|
|
PYBIND11_OVERRIDE_PURE(
|
|
std::string, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
getBackendName, /* Name of function in C++ */
|
|
);
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> allgather(
|
|
std::vector<std::vector<at::Tensor>>& outputTensors,
|
|
std::vector<at::Tensor>& inputTensors,
|
|
const AllgatherOptions& opts = AllgatherOptions()) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
allgather, /* Name of function in C++ */
|
|
outputTensors,
|
|
inputTensors,
|
|
opts);
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> allreduce(
|
|
std::vector<at::Tensor>& tensors,
|
|
const AllreduceOptions& opts = AllreduceOptions()) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
allreduce, /* Name of function in C++ */
|
|
tensors,
|
|
opts);
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> barrier(
|
|
const BarrierOptions& opts = BarrierOptions()) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
barrier, /* Name of function in C++ */
|
|
opts);
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> broadcast(
|
|
std::vector<at::Tensor>& tensors,
|
|
const BroadcastOptions& opts = BroadcastOptions()) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
broadcast, /* Name of function in C++ */
|
|
tensors,
|
|
opts);
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> reduce_scatter(
|
|
std::vector<at::Tensor>& outputTensors,
|
|
std::vector<std::vector<at::Tensor>>& inputTensors,
|
|
const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
reduce_scatter, /* Name of function in C++ */
|
|
outputTensors,
|
|
inputTensors,
|
|
opts);
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> send(
|
|
std::vector<at::Tensor>& tensors,
|
|
int dstRank,
|
|
int tag) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
send, /* Name of function in C++ */
|
|
tensors,
|
|
dstRank,
|
|
tag);
|
|
}
|
|
|
|
c10::intrusive_ptr<Work> recv(
|
|
std::vector<at::Tensor>& tensors,
|
|
int srcRank,
|
|
int tag) override {
|
|
PYBIND11_OVERRIDE(
|
|
c10::intrusive_ptr<Work>, /* Return type */
|
|
ProcessGroup, /* Parent class */
|
|
recv, /* Name of function in C++ */
|
|
tensors,
|
|
srcRank,
|
|
tag);
|
|
}
|
|
};
|
|
|
|
class TORCH_API PythonOnCompletionHook {
|
|
public:
|
|
// Wraps a py::object hook and acquires Python GIL in dtor before
|
|
// destructing the hook object.
|
|
PythonOnCompletionHook(py::object hook) : hook_(std::move(hook)) {}
|
|
|
|
~PythonOnCompletionHook() {
|
|
py::gil_scoped_acquire ag;
|
|
hook_.dec_ref();
|
|
// Explicitly set hook_ to nullptr to prevent py::object's dtor
|
|
// to decref on the PyObject again.
|
|
// See Note [Destructing py::object] in python_ivalue.h
|
|
hook_.ptr() = nullptr;
|
|
}
|
|
|
|
void operator()(std::shared_ptr<WorkInfo> workInfo) const {
|
|
std::exception_ptr eptr;
|
|
{
|
|
py::gil_scoped_acquire acquire;
|
|
try {
|
|
hook_(workInfo);
|
|
} catch (py::error_already_set& e) {
|
|
// py::error_already_set requires GIL to destruct, take
|
|
// special care.
|
|
eptr = std::make_exception_ptr(std::runtime_error(e.what()));
|
|
e.restore();
|
|
PyErr_Clear();
|
|
} catch (std::exception& e) {
|
|
eptr = std::current_exception();
|
|
}
|
|
}
|
|
// No more Python-related stuff at this point, i.e., this
|
|
// exception can be captured and handled by PG backend.
|
|
if (eptr)
|
|
std::rethrow_exception(eptr);
|
|
}
|
|
|
|
private:
|
|
py::object hook_;
|
|
};
|
|
|
|
} // namespace c10d
|