Summary:
This reverts commit d73c830e23.
We have observed a significant perf drop when training ResNext101 with multiple AMD GPUs:
Before:
https://ci.pytorch.org/jenkins/job/caffe2-builds/job/py2-clang7-rocmdeb-ubuntu16.04-bench/1636/console
2 GPUs ResNext training got 150~160 imgs/sec
4 GPUs ResNext training got 270~280 imgs/sec
After:
https://ci.pytorch.org/jenkins/job/caffe2-builds/job/py2-clang7-rocmdeb-ubuntu16.04-bench/1637/console
Both 2 and 4 GPUs ResNext training drop to 110~120 imgs/sec
A similar perf drop is seen on ResNet50 training jobs as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18680
Differential Revision: D14702941
Pulled By: bddppq
fbshipit-source-id: 828141805afc23f25c08d4a2eb6d4b99f817c128
445 lines · 11 KiB · Plaintext
#include "caffe2/core/context_gpu.h"
|
|
#include "caffe2/operators/boolean_mask_ops.h"
|
|
|
|
#include <cub/cub.cuh>
|
|
|
|
namespace caffe2 {
|
|
|
|
namespace {
|
|
__global__ void BooleanMaskCopyKernel(
|
|
const int64_t numOfOutput,
|
|
const int64_t numBytes,
|
|
const int64_t* indices,
|
|
const uint8_t* src,
|
|
uint8_t* dest) {
|
|
for (int64_t i = blockIdx.x; i < numOfOutput; i += gridDim.x) {
|
|
const auto srcBase = indices[i] * numBytes;
|
|
const auto destBase = i * numBytes;
|
|
for (int64_t j = threadIdx.x; j < numBytes; j += blockDim.x) {
|
|
dest[destBase + j] = src[srcBase + j];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <>
class BooleanMaskOp<CUDAContext> final : public Operator<CUDAContext> {
 public:
  BooleanMaskOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CUDAContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    const auto& src = Input(0);
    const auto& mask = Input(1);
    auto* dest = Output(0);

    CAFFE_ENFORCE(src.dim() >= 1);
    CAFFE_ENFORCE_EQ(mask.dim(), 1);
    CAFFE_ENFORCE(src.size(0) == mask.size(0));

    const auto* maskData = mask.data<bool>();
    const auto outerSize = mask.size(0);
    ReinitializeTensor(
        &indices_, {outerSize}, at::dtype<int64_t>().device(CUDA));
    auto* indicesData = indices_.mutable_data<int64_t>();

    // First cub::DeviceSelect::Flagged pass: the null temp-storage pointer
    // only queries the number of scratch bytes needed for the selection.
    size_t numBytes = 0;
    cub::CountingInputIterator<int> itr(0);
    cub::DeviceSelect::Flagged(
        nullptr,
        numBytes,
        itr,
        maskData,
        indicesData,
        static_cast<int64_t*>(nullptr),
        outerSize,
        context_.cuda_stream());

    auto numint64_t =
        static_cast<int64_t>((numBytes + sizeof(int64_t) - 1) / sizeof(int64_t));
    // allocate one more int64_t at the end of scratch for storing numOfOutput
    ReinitializeTensor(
        &scratch_, {numint64_t + 1}, at::dtype<int64_t>().device(CUDA));
    auto* scratchData = scratch_.mutable_data<int64_t>();
    auto* numOfOutputData = scratchData + numint64_t;

    // Second pass: compact the indices of set mask entries into indicesData
    // and write the number of selected rows into numOfOutputData.
    cub::DeviceSelect::Flagged(
        static_cast<void*>(scratchData),
        numBytes,
        itr,
        maskData,
        indicesData,
        numOfOutputData,
        outerSize,
        context_.cuda_stream());

    // Copy numOfOutput from gpu to cpu
    int64_t numOfOutput;
    context_.CopyToCPU(1, numOfOutputData, &numOfOutput);

    indices_.Resize(numOfOutput);
    std::vector<int64_t> dims = src.sizes().vec();
    dims[0] = numOfOutput;
    dest->Resize(dims);
    auto* destData = (uint8_t*)dest->raw_mutable_data(src.meta());
    const auto* srcData = (uint8_t*)src.raw_data();
    if (OutputSize() == 2) {
      auto* indicesOut = Output(1, {numOfOutput}, at::dtype<int64_t>());
      indicesOut->template mutable_data<int64_t>();
    }

    if (numOfOutput > 0) {
      BooleanMaskCopyKernel<<<
          min(numOfOutput, static_cast<int64_t>(CAFFE_MAXIMUM_NUM_BLOCKS)),
          CAFFE_CUDA_NUM_THREADS,
          0,
          context_.cuda_stream()>>>(
          numOfOutput,
          src.size_from_dim(1) * src.meta().itemsize(),
          indicesData,
          srcData,
          destData);

      if (OutputSize() == 2) {
        Output(1)->CopyFrom(indices_, /* async */ true);
      }
    }

    return true;
  }

 private:
  Tensor indices_;
  Tensor scratch_;
};

REGISTER_CUDA_OPERATOR(BooleanMask, BooleanMaskOp<CUDAContext>);
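The two cub::DeviceSelect::Flagged calls above follow CUB's standard two-pass protocol: the first call with a null temp-storage pointer only reports the scratch size, the second does the work. Below is a minimal standalone sketch of that pattern, not part of the original file; the FlaggedIndices helper name is illustrative, and plain cudaMalloc stands in for the operator's scratch_ tensor.

// Sketch (illustrative only): two-pass cub::DeviceSelect::Flagged that writes
// the indices of set flags into d_indices and the selected count into
// d_num_selected. Error checking omitted for brevity.
#include <cuda_runtime.h>
#include <cub/cub.cuh>

void FlaggedIndices(
    const bool* d_mask,       // device pointer, num_items flags
    int num_items,
    int64_t* d_indices,       // device pointer, room for num_items indices
    int64_t* d_num_selected,  // device pointer, one counter
    cudaStream_t stream) {
  cub::CountingInputIterator<int> counting(0);

  // Pass 1: null temp storage -> CUB only reports the scratch size needed.
  size_t temp_bytes = 0;
  cub::DeviceSelect::Flagged(
      nullptr, temp_bytes, counting, d_mask, d_indices, d_num_selected,
      num_items, stream);

  // Pass 2: allocate the scratch space and run the actual selection.
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceSelect::Flagged(
      d_temp, temp_bytes, counting, d_mask, d_indices, d_num_selected,
      num_items, stream);
  cudaFree(d_temp);
}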
namespace {

#define minf (-1.0f * std::numeric_limits<float>::infinity())

// In the batched case (B >= 0) the input is viewed as [B, N, M]; the flat
// index decomposes into i = batch, j = row, k = column, and columns at or
// past seq_lengths[j] are replaced with fill_val.
template <typename T>
__global__ void sequenceMaskKernel(
    int N,
    int M,
    int B,
    const T* in,
    const int* seq_lengths,
    T fill_val,
    T* out) {
  if (B >= 0) {
    CUDA_1D_KERNEL_LOOP(index, B * N * M) {
      int k = index % M;
      int j = (index - k) / M % N;
      int i = (index - M * j - k) / (N * M);

      int ind = N * M * i + M * j + k;
      out[ind] = (k >= seq_lengths[j] ? fill_val : in[ind]);
    }
  } else {
    CUDA_1D_KERNEL_LOOP(index, N * M) {
      int i = index / M;
      int j = index % M;

      out[index] = (j >= seq_lengths[i] ? fill_val : in[index]);
    }
  }
}

template <typename T>
__global__ void repeatedSequenceMaskKernel(
    int N,
    int M,
    int D,
    const T* in,
    const int* seq_lengths,
    T fill_val,
    T* out) {
  CUDA_1D_KERNEL_LOOP(index, N * M * D) {
    int i = index / (D * M);
    int j = (index / D) % M;

    out[index] = (j >= seq_lengths[i] ? fill_val : in[index]);
  }
}

template <typename T>
__global__ void windowMaskKernel(
    int N,
    int M,
    int B,
    const T* in,
    const int* window_centers,
    const int radius,
    T fill_val,
    T* out) {
  if (B >= 0) {
    CUDA_1D_KERNEL_LOOP(index, B * N * M) {
      int k = index % M;
      int j = (index - k) / M % N;
      int i = (index - M * j - k) / (N * M);

      int ind = N * M * i + M * j + k;
      out[ind] =
          (k < window_centers[j] - radius || k > window_centers[j] + radius
               ? fill_val
               : in[ind]);
    }
  } else {
    CUDA_1D_KERNEL_LOOP(index, N * M) {
      int i = index / M;
      int j = index % M;

      out[index] =
          (j < window_centers[i] - radius || j > window_centers[i] + radius
               ? fill_val
               : in[index]);
    }
  }
}

template <typename T>
__global__ void
upperMaskKernel(int N, int M, int B, const T* in, T fill_val, T* out) {
  if (B >= 0) {
    CUDA_1D_KERNEL_LOOP(index, B * N * M) {
      int k = index % M;
      int j = (index - k) / M % N;
      int i = (index - M * j - k) / (N * M);

      int ind = N * M * i + M * j + k;
      out[ind] = (k > j ? fill_val : in[ind]);
    }
  } else {
    CUDA_1D_KERNEL_LOOP(index, N * M) {
      int i = index / M;
      int j = index % M;

      out[index] = (j > i ? fill_val : in[index]);
    }
  }
}

template <typename T>
__global__ void
lowerMaskKernel(int N, int M, int B, const T* in, T fill_val, T* out) {
  if (B >= 0) {
    CUDA_1D_KERNEL_LOOP(index, B * N * M) {
      int k = index % M;
      int j = (index - k) / M % N;
      int i = (index - M * j - k) / (N * M);

      int ind = N * M * i + M * j + k;
      out[ind] = (k < j ? fill_val : in[ind]);
    }
  } else {
    CUDA_1D_KERNEL_LOOP(index, N * M) {
      int i = index / M;
      int j = index % M;

      out[index] = (j < i ? fill_val : in[index]);
    }
  }
}

template <typename T>
__global__ void
upperDiagMaskKernel(int N, int M, int B, const T* in, T fill_val, T* out) {
  if (B >= 0) {
    CUDA_1D_KERNEL_LOOP(index, B * N * M) {
      int k = index % M;
      int j = (index - k) / M % N;
      int i = (index - M * j - k) / (N * M);

      int ind = N * M * i + M * j + k;
      out[ind] = (k >= j ? fill_val : in[ind]);
    }
  } else {
    CUDA_1D_KERNEL_LOOP(index, N * M) {
      int i = index / M;
      int j = index % M;

      out[index] = (j >= i ? fill_val : in[index]);
    }
  }
}

template <typename T>
__global__ void
lowerDiagMaskKernel(int N, int M, int B, const T* in, T fill_val, T* out) {
  if (B >= 0) {
    CUDA_1D_KERNEL_LOOP(index, B * N * M) {
      int k = index % M;
      int j = (index - k) / M % N;
      int i = (index - M * j - k) / (N * M);

      int ind = N * M * i + M * j + k;
      out[ind] = (k <= j ? fill_val : in[ind]);
    }
  } else {
    CUDA_1D_KERNEL_LOOP(index, N * M) {
      int i = index / M;
      int j = index % M;

      out[index] = (j <= i ? fill_val : in[index]);
    }
  }
}

} // namespace
template <>
bool SequenceMaskOp<CUDAContext>::RunOnDevice() {
  return DispatchHelper<TensorTypes<at::Half, float>>::call(this, Input(0));
}

template <>
template <class T>
bool SequenceMaskOp<CUDAContext>::DoRunWithType() {
  const Tensor* input = &Input(0);
  const Tensor* sequence_lengths = nullptr;
  const Tensor* window_centers = nullptr;

  if (mode_ == "sequence") {
    sequence_lengths = &Input(1);
  } else if (mode_ == "window") {
    window_centers = &Input(1);
  }

  auto* output = Output(0, input->sizes(), at::dtype<T>());

  const auto canonical_axis = input->canonical_axis_index(axis_);

  // canonical_batch is non-negative if batching, -1 otherwise
  int canonical_batch = -1;
  if ((HasArgument("batch"))) {
    canonical_batch = input->canonical_axis_index(batch_);
  }

  // make sure batch < axis
  if (canonical_batch >= 0) {
    CAFFE_ENFORCE_LT(canonical_batch, canonical_axis);
  }

  // if no batch, then left is product of dims up to axis
  // otherwise, left is product of dims between batch and axis
  const int left =
      (canonical_batch >= 0
           ? input->size_between_dim(canonical_batch, canonical_axis)
           : input->size_to_dim(canonical_axis));
  const int right = input->size_from_dim(canonical_axis);

  // product of dims from 1 to batch
  const int batch_dim =
      (canonical_batch >= 0
           ? input->size_to_dim(canonical_batch) * input->dim(canonical_batch)
           : -1);

  T fill_val = convert::To<float, T>(grad_ ? 0.0f : fill_val_);
  if (mode_ == "sequence") {
    if (HasArgument("repeat_from_axis")) {
      const int canonical_repeat_from =
          input->canonical_axis_index(repeat_from_);
      const int repeated_dims = input->size_from_dim(canonical_repeat_from);
      const int masked_dims = right / repeated_dims;
      repeatedSequenceMaskKernel<<<
          CAFFE_GET_BLOCKS(left * right),
          CAFFE_CUDA_NUM_THREADS,
          0,
          context_.cuda_stream()>>>(
          left,
          masked_dims,
          repeated_dims,
          input->data<T>(),
          sequence_lengths->data<int>(),
          fill_val,
          output->template mutable_data<T>());
    } else {
      sequenceMaskKernel<<<
          CAFFE_GET_BLOCKS(left * right),
          CAFFE_CUDA_NUM_THREADS,
          0,
          context_.cuda_stream()>>>(
          left,
          right,
          batch_dim,
          input->data<T>(),
          sequence_lengths->data<int>(),
          fill_val,
          output->template mutable_data<T>());
    }
  } else if (mode_ == "window") {
    windowMaskKernel<<<
        CAFFE_GET_BLOCKS(left * right),
        CAFFE_CUDA_NUM_THREADS,
        0,
        context_.cuda_stream()>>>(
        left,
        right,
        batch_dim,
        input->data<T>(),
        window_centers->data<int>(),
        radius_,
        fill_val,
        output->template mutable_data<T>());
  } else if (mode_ == "upper") {
    upperMaskKernel<<<
        CAFFE_GET_BLOCKS(left * right),
        CAFFE_CUDA_NUM_THREADS,
        0,
        context_.cuda_stream()>>>(
        left,
        right,
        batch_dim,
        input->data<T>(),
        fill_val,
        output->template mutable_data<T>());
  } else if (mode_ == "lower") {
    lowerMaskKernel<<<
        CAFFE_GET_BLOCKS(left * right),
        CAFFE_CUDA_NUM_THREADS,
        0,
        context_.cuda_stream()>>>(
        left,
        right,
        batch_dim,
        input->data<T>(),
        fill_val,
        output->template mutable_data<T>());
  } else if (mode_ == "upperdiag") {
    upperDiagMaskKernel<<<
        CAFFE_GET_BLOCKS(left * right),
        CAFFE_CUDA_NUM_THREADS,
        0,
        context_.cuda_stream()>>>(
        left,
        right,
        batch_dim,
        input->data<T>(),
        fill_val,
        output->template mutable_data<T>());
  } else if (mode_ == "lowerdiag") {
    lowerDiagMaskKernel<<<
        CAFFE_GET_BLOCKS(left * right),
        CAFFE_CUDA_NUM_THREADS,
        0,
        context_.cuda_stream()>>>(
        left,
        right,
        batch_dim,
        input->data<T>(),
        fill_val,
        output->template mutable_data<T>());
  } else {
    CAFFE_ENFORCE(false, "Unsupported mode for SequenceMaskOp!");
  }

  return true;
}

REGISTER_CUDA_OPERATOR(SequenceMask, SequenceMaskOp<CUDAContext>);

} // namespace caffe2
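For reference, a minimal host-side sketch of the left/right/batch_dim arithmetic that DoRunWithType performs before launching a mask kernel. The FlattenForMask helper and the plain std::vector shape are illustrative assumptions, not part of the operator; the comments map each value back to the Caffe2 calls used above.

// Sketch (illustrative only): flatten an input shape into the (B, N, M)
// layout that the mask kernels expect, mirroring size_to_dim,
// size_between_dim, and size_from_dim in DoRunWithType.
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

struct MaskShape {
  int64_t batch_dim; // product of dims [0, batch], or -1 when not batching (B)
  int64_t left;      // rows seen by the kernels (N)
  int64_t right;     // columns seen by the kernels (M)
};

MaskShape FlattenForMask(const std::vector<int64_t>& dims, int axis, int batch) {
  auto prod = [&](int begin, int end) {
    return std::accumulate(
        dims.begin() + begin,
        dims.begin() + end,
        int64_t{1},
        std::multiplies<int64_t>());
  };
  MaskShape s;
  if (batch >= 0) {
    // Batched: batch_dim covers dims [0, batch]; left covers the dims strictly
    // between batch and axis (size_between_dim in the operator).
    s.batch_dim = prod(0, batch) * dims[batch];
    s.left = prod(batch + 1, axis);
  } else {
    s.batch_dim = -1;
    s.left = prod(0, axis); // size_to_dim(axis)
  }
  s.right = prod(axis, static_cast<int>(dims.size())); // size_from_dim(axis)
  return s;
}

// Example: dims = {8, 4, 10, 10}, axis = 3, batch = 1 gives batch_dim = 32,
// left = 10, right = 10, and B * N * M equals the input's total element count,
// matching the (B, N, M) indexing in sequenceMaskKernel.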