mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: This directory is opted-in to clang-format but is not format-clean. This blocks continuous formatting from being enabled on fbcode, and causes hassle for other codemods that leave inconsistent formatting. This diff runs clang-format, which is widely used and considered safe. If you are unhappy with the formatting of a particular block, please *accept this diff* and then in a stacked commit undo the change and wrap that code in `// clang-format off` and `// clang-format on`, or `/* clang-format off */` and `/* clang-format on */`. drop-conflicts Test Plan: sandcastleit Reviewed By: jerryzh168 Differential Revision: D22311706 fbshipit-source-id: 1ca59a82e96156a4a5dfad70ba3e64d44c5e762a
249 lines
7.6 KiB
C++
249 lines
7.6 KiB
C++
#ifndef CAFFE2_MPI_MPI_OPS_H_
|
|
#define CAFFE2_MPI_MPI_OPS_H_
|
|
|
|
#include <mpi.h>

#include <limits>

#include "caffe2/core/operator.h"
#include "caffe2/mpi/mpi_common.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
// TODO(jiayq): if needed, write up the use of color and key with MPI split.
|
|
// Currently, the operator simply creates a communicator that has the
|
|
// same topology as the Caffe2 global communicator.
|
|
template <class Context>
|
|
class MPICreateCommonWorldOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
MPICreateCommonWorldOp(const OperatorDef& operator_def, Workspace* ws)
|
|
: Operator<Context>(operator_def, ws) {}
|
|
|
|
bool RunOnDevice() override {
|
|
OperatorBase::Outputs()[0]->Reset(new MPICommonWorldWrapper());
|
|
return true;
|
|
}
|
|
};
|
|
|
|
template <class Context>
|
|
class MPIBroadcastOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
MPIBroadcastOp(const OperatorDef& operator_def, Workspace* ws)
|
|
: Operator<Context>(operator_def, ws),
|
|
root_(OperatorBase::template GetSingleArgument<int>("root", 0)) {}
|
|
~MPIBroadcastOp() {}
|
|
|
|
bool RunOnDevice() override {
|
|
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
|
|
CAFFE_ENFORCE(
|
|
OperatorBase::OutputIsTensorType(0, Context::GetDeviceType()),
|
|
"Output is of wrong type.");
|
|
auto* output = Output(0);
|
|
// Make sure that output is already allocated.
|
|
CAFFE_ENFORCE(
|
|
output->numel() > 0,
|
|
"Broadcast op uses in-place operation so the output "
|
|
"should be already allocated.");
|
|
MPI_CHECK(MPI_Bcast(
|
|
output->raw_mutable_data(),
|
|
output->nbytes(),
|
|
MPIDataTypeWrapper<char>::type(),
|
|
root_,
|
|
comm));
|
|
return true;
|
|
}
|
|
|
|
protected:
|
|
int root_;
|
|
};
|
|
|
|
// MPIReduceOp does Reduce using MPI. Currently, only SUM is supported.
|
|
template <typename T, class Context>
|
|
class MPIReduceOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
MPIReduceOp(const OperatorDef& operator_def, Workspace* ws)
|
|
: Operator<Context>(operator_def, ws),
|
|
root_(OperatorBase::template GetSingleArgument<int>("root", 0)) {}
|
|
~MPIReduceOp() {}
|
|
|
|
bool RunOnDevice() override {
|
|
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
|
|
auto& input = Input(1);
|
|
auto* output = Output(0, input.sizes(), at::dtype<T>());
|
|
MPI_CHECK(MPI_Reduce(
|
|
const_cast<T*>(input.template data<T>()),
|
|
output->template mutable_data<T>(),
|
|
input.numel(),
|
|
MPIDataTypeWrapper<T>::type(),
|
|
MPI_SUM,
|
|
root_,
|
|
comm));
|
|
return true;
|
|
}
|
|
|
|
protected:
|
|
int root_;
|
|
};
|
|
|
|
// MPIAllgatherOp does MPIAllgather using MPI.
|
|
template <typename T, class Context>
|
|
class MPIAllgatherOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
USE_SIMPLE_CTOR_DTOR(MPIAllgatherOp);
|
|
|
|
bool RunOnDevice() override {
|
|
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
|
|
auto& input = Input(1);
|
|
auto* output = Output(0);
|
|
vector<int64_t> output_dims = input.sizes().vec();
|
|
output_dims[0] *= OperatorBase::Input<MPICommonWorldWrapper>(0).size();
|
|
output->Resize(output_dims);
|
|
MPI_CHECK(MPI_Allgather(
|
|
const_cast<T*>(input.template data<T>()),
|
|
input.numel(),
|
|
MPIDataTypeWrapper<T>::type(),
|
|
output->template mutable_data<T>(),
|
|
input.numel(),
|
|
MPIDataTypeWrapper<T>::type(),
|
|
comm));
|
|
return true;
|
|
}
|
|
};
|
|
|
|
// MPIAllreduceOp does MPIAllreduce using MPI. Currently, only SUM is supported.
|
|
template <typename T, class Context>
|
|
class MPIAllreduceOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
USE_SIMPLE_CTOR_DTOR(MPIAllreduceOp);
|
|
|
|
bool RunOnDevice() override {
|
|
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(0).comm();
|
|
auto& input = Input(1);
|
|
auto* output = Output(0, input.sizes(), at::dtype<T>());
|
|
void* source;
|
|
if (output->template mutable_data<T>() == input.template data<T>()) {
|
|
// We are doing in-place call. Special case handling.
|
|
source = MPI_IN_PLACE;
|
|
} else {
|
|
// Normal allreduce takes the source from the input.
|
|
source = const_cast<T*>(input.template data<T>());
|
|
}
|
|
MPI_CHECK(MPI_Allreduce(
|
|
source,
|
|
output->template mutable_data<T>(),
|
|
input.numel(),
|
|
MPIDataTypeWrapper<T>::type(),
|
|
MPI_SUM,
|
|
comm));
|
|
return true;
|
|
}
|
|
};
|
|
|
|
template <class Context>
|
|
class MPISendTensorOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
MPISendTensorOp(const OperatorDef& def, Workspace* ws)
|
|
: Operator<Context>(def, ws),
|
|
OP_SINGLE_ARG(int, "dst", dst_, MPI_ANY_SOURCE),
|
|
OP_SINGLE_ARG(int, "tag", tag_, MPI_ANY_TAG),
|
|
OP_SINGLE_ARG(bool, "raw_buffer", raw_buffer_, false) {
|
|
CAFFE_ENFORCE(raw_buffer_, "non-raw-buffer transfer not supported yet.");
|
|
CAFFE_ENFORCE(
|
|
dst_ != MPI_ANY_SOURCE || def.input_size() == 4,
|
|
"You should explicitly specify the to rank either via "
|
|
"argument or via input blobs.");
|
|
CAFFE_ENFORCE(
|
|
tag_ != MPI_ANY_TAG || def.input_size() == 4,
|
|
"You should explicitly specify the tag either via "
|
|
"argument or via input blobs.");
|
|
}
|
|
|
|
bool RunOnDevice() override {
|
|
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(COMM).comm();
|
|
auto& input = Input(INPUT);
|
|
if (InputSize() == 4) {
|
|
dst_ = OperatorBase::Input<Tensor>(DST, CPU).template data<int>()[0];
|
|
tag_ = OperatorBase::Input<Tensor>(TAG, CPU).template data<int>()[0];
|
|
}
|
|
if (raw_buffer_) {
|
|
// We need to do a const cast to cope with the fact that, before OpenMPI
|
|
// 1.7, MPI_Send expects a non-const pointer although it uses it in a
|
|
// const way.
|
|
MPI_CHECK(MPI_Send(
|
|
const_cast<void*>(input.raw_data()),
|
|
input.nbytes(),
|
|
MPI_CHAR,
|
|
dst_,
|
|
tag_,
|
|
comm));
|
|
} else {
|
|
CAFFE_NOT_IMPLEMENTED;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
protected:
|
|
int dst_;
|
|
int tag_;
|
|
bool raw_buffer_;
|
|
|
|
INPUT_TAGS(COMM, INPUT, DST, TAG);
|
|
};
|
|
|
|
template <class Context>
|
|
class MPIReceiveTensorOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
MPIReceiveTensorOp(const OperatorDef& def, Workspace* ws)
|
|
: Operator<Context>(def, ws),
|
|
OP_SINGLE_ARG(int, "src", src_, MPI_ANY_SOURCE),
|
|
OP_SINGLE_ARG(int, "tag", tag_, MPI_ANY_TAG),
|
|
OP_SINGLE_ARG(bool, "raw_buffer", raw_buffer_, false) {
|
|
CAFFE_ENFORCE(raw_buffer_, "non-raw-buffer transfer not supported yet.");
|
|
}
|
|
|
|
bool RunOnDevice() override {
|
|
MPI_Comm comm = OperatorBase::Input<MPICommonWorldWrapper>(COMM).comm();
|
|
if (InputSize() == 4) {
|
|
src_ = OperatorBase::Input<Tensor>(SRC_IN, CPU).template data<int>()[0];
|
|
tag_ = OperatorBase::Input<Tensor>(TAG_IN, CPU).template data<int>()[0];
|
|
}
|
|
MPI_Status status;
|
|
if (raw_buffer_) {
|
|
auto* output = Output(OUTPUT);
|
|
MPI_CHECK(MPI_Recv(
|
|
output->raw_mutable_data(),
|
|
output->nbytes(),
|
|
MPI_CHAR,
|
|
src_,
|
|
tag_,
|
|
comm,
|
|
&status));
|
|
} else {
|
|
CAFFE_NOT_IMPLEMENTED;
|
|
}
|
|
auto* src_out = OperatorBase::Output<Tensor>(SRC_OUT, CPU);
|
|
src_out->Resize();
|
|
src_out->template mutable_data<int>()[0] = status.MPI_SOURCE;
|
|
auto* tag_out = OperatorBase::Output<Tensor>(TAG_OUT, CPU);
|
|
tag_out->Resize();
|
|
tag_out->template mutable_data<int>()[0] = status.MPI_TAG;
|
|
return true;
|
|
}
|
|
|
|
protected:
|
|
int src_;
|
|
int tag_;
|
|
bool raw_buffer_;
|
|
INPUT_TAGS(COMM, INPUT, SRC_IN, TAG_IN);
|
|
OUTPUT_TAGS(OUTPUT, SRC_OUT, TAG_OUT);
|
|
};
|
|
|
|
} // namespace caffe2
|
|
|
|
#endif // CAFFE2_MPI_MPI_OPS_H_
|