mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This adds the differentiable collective -- all_to_all_single_grad. This is the initial proof of concept PR and I will be adding the remaining collectives in follow up PRs. This adds a new function called `all_to_all_single_autograd` which is the autograd variant of `all_to_all_single`. For backwards compatibility + initial testing we wanted to make the autograd variant separate to avoid regressions. This uses `autograd::Function` to register an Autograd op that calls the original `_c10d_functional::all_to_all_single` via the dispatcher. This works with compile and inductor as opposed to the previous Python implementation that had issues. As this uses the existing `_c10d_functional` ops we don't need to register any meta functions or lowering. To avoid cudaStream issues this explicitly calls `wait_tensor` in the backward method to ensure it runs under the same stream as the async operation. This hurts performance but can be alleviated potentially using `compile`. Related work: https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/comm_ops.py Test plan: ``` pytest test/distributed/test_functional_api.py -k test_all_to_all_single_compile pytest test/distributed/test_functional_api.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/123599 Approved by: https://github.com/yifuwang
237 lines
7.2 KiB
C++
#pragma once
|
|
|
|
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
|
#include <torch/csrc/jit/python/pybind_utils.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
|
|
namespace c10d {
|
|
|
|
// PyProcessGroup is a pybind11 trampoline class to allow a Python
// class to inherit from torch.distributed.ProcessGroup.
// Each overridden collective forwards the call to the Python subclass
// (if it provides an override) via the PYBIND11_OVERRIDE machinery;
// otherwise the base ProcessGroup implementation is used.
class PyProcessGroup : public ProcessGroup {
 public:
  // PyWork is a pybind11 trampoline class to allow a Python
  // class to inherit from torch.distributed.Work.
  class PyWork : public Work {
   public:
    PyWork() = default;

    // Blocks until the work completes (or `timeout` elapses), dispatching
    // to a Python `wait` override when one exists.
    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
      PYBIND11_OVERRIDE(
          bool, /* Return type */
          Work, /* Parent class */
          wait, /* Name of function in C++ */
          timeout);
    }

    // Returns the Future associated with this work object. Dispatches to a
    // Python `get_future` override when present, unwrapping its
    // PythonFutureWrapper back into the underlying ivalue::Future.
    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
      // We cannot use PYBIND11_OVERRIDE because:
      // 1. We have to >MANUALLY< unwrap the PyFutureWrapper and
      // 2. The python name is get_future
      pybind11::gil_scoped_acquire gil;
      auto override =
          pybind11::get_override(static_cast<const Work*>(this), "get_future");

      if (override) {
        py::object o = override();
        auto futWrapper =
            o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>();
        return futWrapper->fut;
      }

      // No Python override: fall back to the C++ base implementation.
      return Work::getFuture();
    }
  };

  using ProcessGroup::ProcessGroup;

  // Pure virtual in ProcessGroup: Python subclasses MUST implement
  // get_backend_name / getBackendName.
  const std::string getBackendName() const override {
    PYBIND11_OVERRIDE_PURE(
        std::string, /* Return type */
        ProcessGroup, /* Parent class */
        getBackendName, /* Name of function in C++ */
    );
  }

  // Gathers `inputTensors` from all ranks into `outputTensors`.
  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        allgather, /* Name of function in C++ */
        outputTensors,
        inputTensors,
        opts);
  }

  // Coalesced tensor allgather variant.
  c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        allgather_into_tensor_coalesced, /* Name of function in C++ */
        outputTensors,
        inputTensors,
        opts);
  }

  // Reduces `tensors` across all ranks in place.
  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        allreduce, /* Name of function in C++ */
        tensors,
        opts);
  }

  // Coalesced allreduce variant.
  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        allreduce_coalesced, /* Name of function in C++ */
        tensors,
        opts);
  }

  // All-to-all over single flat buffers with per-rank split sizes.
  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        alltoall_base, /* Name of function in C++ */
        outputBuffer,
        inputBuffer,
        outputSplitSizes,
        inputSplitSizes,
        opts);
  }

  // Synchronizes all ranks.
  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        barrier, /* Name of function in C++ */
        opts);
  }

  // Broadcasts `tensors` from the root rank (in `opts`) to all ranks.
  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        broadcast, /* Name of function in C++ */
        tensors,
        opts);
  }

  // Reduces then scatters `inputTensors` so each rank receives one shard.
  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        reduce_scatter, /* Name of function in C++ */
        outputTensors,
        inputTensors,
        opts);
  }

  // Coalesced tensor reduce-scatter variant.
  c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        reduce_scatter_tensor_coalesced, /* Name of function in C++ */
        outputTensors,
        inputTensors,
        opts);
  }

  // Point-to-point send of `tensors` to `dstRank`, matched by `tag`.
  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        send, /* Name of function in C++ */
        tensors,
        dstRank,
        tag);
  }

  // Point-to-point receive of `tensors` from `srcRank`, matched by `tag`.
  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        recv, /* Name of function in C++ */
        tensors,
        srcRank,
        tag);
  }
};
|
|
|
|
class TORCH_PYTHON_API PythonOnCompletionHook {
|
|
public:
|
|
// Wraps a py::object hook and acquires Python GIL in dtor before
|
|
// destructing the hook object.
|
|
PythonOnCompletionHook(py::object hook) : hook_(std::move(hook)) {}
|
|
|
|
~PythonOnCompletionHook() {
|
|
py::gil_scoped_acquire ag;
|
|
hook_.dec_ref();
|
|
// Explicitly set hook_ to nullptr to prevent py::object's dtor
|
|
// to decref on the PyObject again.
|
|
// See Note [Destructing py::object] in python_ivalue.h
|
|
hook_.ptr() = nullptr;
|
|
}
|
|
|
|
void operator()(std::shared_ptr<WorkInfo> workInfo) const {
|
|
std::exception_ptr eptr;
|
|
{
|
|
py::gil_scoped_acquire acquire;
|
|
try {
|
|
hook_(workInfo);
|
|
} catch (py::error_already_set& e) {
|
|
// py::error_already_set requires GIL to destruct, take
|
|
// special care.
|
|
eptr = std::make_exception_ptr(std::runtime_error(e.what()));
|
|
e.restore();
|
|
PyErr_Clear();
|
|
} catch (std::exception& e) {
|
|
eptr = std::current_exception();
|
|
}
|
|
}
|
|
// No more Python-related stuff at this point, i.e., this
|
|
// exception can be captured and handled by PG backend.
|
|
if (eptr)
|
|
std::rethrow_exception(eptr);
|
|
}
|
|
|
|
private:
|
|
py::object hook_;
|
|
};
|
|
|
|
} // namespace c10d
|