Generalize PoolingOp(cuDNN) to compute 2D and 3D pooling.

Reviewed By: akyrola

Differential Revision: D5090689

fbshipit-source-id: f9f11e12adc0ee8db088f3397a8c33aa31eb5deb
Ahmed Taei, 2017-05-19 10:03:53 -07:00 (committed by Facebook Github Bot)
parent 1b7497807f · commit 32bf7a2c2b
2 changed files with 152 additions and 68 deletions

[File 1/2: cuDNN pooling operator source]
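The heart of this file's change: cuDNN's 4-D descriptor API takes an explicit layout enum, while its N-d API encodes layout purely through strides, so describing a fifth (depth) axis means supplying the stride vectors that the new setTensorDescriptor helper below hard-codes. As a minimal standalone sketch (describe_both_ways and the extents are illustrative; descriptor creation and error checks elided), the two calls here describe the same 4-D NCHW tensor:

#include <cudnn.h>

// Same 4-D NCHW float tensor, described two ways: the 4-D API infers
// strides from the format enum, while the N-d API takes them explicitly,
// which is what makes a fifth spatial axis expressible at all.
void describe_both_ways(cudnnTensorDescriptor_t a, cudnnTensorDescriptor_t b) {
  const int N = 2, C = 3, H = 8, W = 8;
  cudnnSetTensor4dDescriptor(
      a, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, N, C, H, W);
  int dims[4] = {N, C, H, W};
  int strides[4] = {C * H * W, H * W, W, 1};  // contiguous NCHW strides
  cudnnSetTensorNdDescriptor(b, CUDNN_DATA_FLOAT, 4, dims, strides);
}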

@@ -4,6 +4,45 @@
namespace caffe2 {
+ namespace {
+ template <typename T>
+ void setTensorDescriptor(
+ const int size,
+ const StorageOrder order,
+ const int N,
+ const int C,
+ const int H,
+ const int W,
+ const int D,
+ cudnnTensorDescriptor_t& desc) {
+ if (size == 4) {
+ CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
+ desc,
+ GetCudnnTensorFormat(order),
+ cudnnTypeWrapper<T>::type,
+ N,
+ C,
+ H,
+ W));
+ } else {
+ vector<int> dims = {N, C, H, W, D};
+ vector<int> strides;
+ order == NCHW
+ ? strides.insert(strides.end(), {C * H * W * D, H * W * D, W * D, D, 1})
+ : strides.insert(
+ strides.end(), {H * W * D * C, 1, W * D * C, D * C, C});
+ CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
+ desc,
+ cudnnTypeWrapper<T>::type,
+ size > 3 ? size : 4,
+ dims.data(),
+ strides.data()));
+ }
+ }
+ } // namespace
class CuDNNPoolOp : public ConvPoolOpBase<CUDAContext> {
public:
CuDNNPoolOp(const OperatorDef& operator_def, Workspace* ws)
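As a sanity check on the two stride vectors the helper hard-codes, the sketch below derives them from a physical axis permutation: the innermost physical axis gets stride 1, and each outer axis gets the running product of the extents inside it. StridesFor and the axis orders are illustrative names, not part of this diff:

#include <cassert>
#include <vector>

std::vector<int> StridesFor(const std::vector<int>& logical_dims,
                            const std::vector<int>& physical_axes) {
  // Walk the physical layout from innermost to outermost, accumulating
  // the running product of extents as the stride of each logical axis.
  std::vector<int> strides(logical_dims.size());
  int running = 1;
  for (int i = static_cast<int>(physical_axes.size()) - 1; i >= 0; --i) {
    strides[physical_axes[i]] = running;
    running *= logical_dims[physical_axes[i]];
  }
  return strides;
}

int main() {
  const int N = 2, C = 3, H = 4, W = 5, D = 6;
  // The logical order is always {N, C, H, W, D}; only the layout differs.
  std::vector<int> dims = {N, C, H, W, D};
  assert((StridesFor(dims, {0, 1, 2, 3, 4}) ==  // NCHW-style layout
          std::vector<int>{C * H * W * D, H * W * D, W * D, D, 1}));
  assert((StridesFor(dims, {0, 2, 3, 4, 1}) ==  // NHWC-style layout
          std::vector<int>{H * W * D * C, 1, W * D * C, D * C, C}));
  return 0;
}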
@@ -36,39 +75,46 @@ class CuDNNPoolOp : public ConvPoolOpBase<CUDAContext> {
bool DoRunWithType() {
auto& X = Input(0);
auto* Y = Output(0);
- int N = 0, C = 0, H = 0, W = 0;
+ int N = 0, C = 0, H = 0, W = 0, D = 0;
+ int H_out = 0, W_out = 0, D_out = 0;
+ // cuDNN pooling supports only 2 and 3 spatial dimensions.
+ CAFFE_ENFORCE(X.ndim() >= 4 && X.ndim() <= 5);
switch (order_) {
- case StorageOrder::NHWC:
- N = X.dim32(0); H = X.dim32(1); W = X.dim32(2); C = X.dim32(3);
- break;
- case StorageOrder::NCHW:
- N = X.dim32(0); C = X.dim32(1); H = X.dim32(2); W = X.dim32(3);
- break;
- default:
- LOG(FATAL) << "Unknown storage order: " << order_;
+ case StorageOrder::NHWC:
+ N = X.dim32(0);
+ H = X.dim32(1);
+ W = X.ndim() > 3 ? X.dim32(2) : 1;
+ D = X.ndim() > 4 ? X.dim32(3) : 1;
+ C = X.dim32(X.ndim() - 1);
+ ConvPoolOpBase::SetOutputSize(X, Y, C);
+ H_out = Y->dim32(1);
+ W_out = Y->ndim() > 3 ? Y->dim32(2) : 1;
+ D_out = Y->ndim() > 4 ? Y->dim32(3) : 1;
+ break;
+ case StorageOrder::NCHW:
+ N = X.dim32(0);
+ C = X.dim32(1);
+ H = X.dim32(2);
+ W = X.ndim() > 3 ? X.dim32(3) : 1;
+ D = X.ndim() > 4 ? X.dim32(4) : 1;
+ ConvPoolOpBase::SetOutputSize(X, Y, C);
+ H_out = Y->dim32(2);
+ W_out = Y->ndim() > 3 ? Y->dim32(3) : 1;
+ D_out = Y->ndim() > 4 ? Y->dim32(4) : 1;
+ break;
+ default:
+ LOG(FATAL) << "Unknown storage order: " << order_;
}
- ConvPoolOpBase::SetOutputSize(X, Y, C);
if (cudnn_input_dims_ != X.dims()) {
// Dimensions changed; we will need to re-initialize things.
VLOG(1) << "Changing the cudnn descriptor configurations.";
cudnn_input_dims_ = X.dims();
- CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
- bottom_desc_,
- GetCudnnTensorFormat(order_),
- cudnnTypeWrapper<T>::type,
- N,
- C,
- H,
- W));
- CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
- top_desc_,
- GetCudnnTensorFormat(order_),
- cudnnTypeWrapper<T>::type,
- N,
- C,
- order_ == StorageOrder::NCHW ? Y->dim32(2) : Y->dim32(1),
- order_ == StorageOrder::NCHW ? Y->dim32(3) : Y->dim32(2)));
+ setTensorDescriptor<T>(X.ndim(), order_, N, C, H, W, D, bottom_desc_);
+ setTensorDescriptor<T>(
+ Y->ndim(), order_, N, C, H_out, W_out, D_out, top_desc_);
if (pad_t() != pad_l() || pad_l() != pad_r()) {
CAFFE_ENFORCE(
legacy_pad_ == LegacyPadding::CAFFE_LEGACY_POOLING,
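For reference before the next hunk: the H_out/W_out/D_out values read back above are produced by SetOutputSize, which applies the standard pooling arithmetic per spatial axis. A worked sketch under non-legacy padding (PooledExtent is an illustrative name, not Caffe2 API):

// Output extent of one pooled axis: count the positions of a
// kernel-wide window sliding by `stride` across the padded input.
int PooledExtent(int in, int pad_head, int pad_tail, int kernel, int stride) {
  return (in + pad_head + pad_tail - kernel) / stride + 1;
}
// Example: in = 7, pads of 1 and 1, kernel = 3, stride = 2
// gives (7 + 2 - 3) / 2 + 1 = 4 output positions.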
@@ -76,16 +122,27 @@ class CuDNNPoolOp : public ConvPoolOpBase<CUDAContext> {
"the only exception of the caffe legacy pooling case where we "
"try to preserve backward compatibility with Caffe.");
}
- CUDNN_ENFORCE(cudnnSetPooling2dDescriptor(
- pooling_desc_,
- mode_,
- CUDNN_NOT_PROPAGATE_NAN,
- kernel_h(),
- kernel_w(),
- pad_t(),
- pad_l(),
- stride_h(),
- stride_w()));
+ if (kernel_.size() == 2) {
+ CUDNN_ENFORCE(cudnnSetPooling2dDescriptor(
+ pooling_desc_,
+ mode_,
+ CUDNN_NOT_PROPAGATE_NAN,
+ kernel_h(),
+ kernel_w(),
+ pad_t(),
+ pad_l(),
+ stride_h(),
+ stride_w()));
+ } else {
+ CUDNN_ENFORCE(cudnnSetPoolingNdDescriptor(
+ pooling_desc_,
+ mode_,
+ CUDNN_NOT_PROPAGATE_NAN,
+ kernel_.size(),
+ kernel_.data(),
+ pads_.data(),
+ stride_.data()));
+ }
}
// Carry out the pooling computation.
CUDNN_ENFORCE(cudnnPoolingForward(
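The hunk above is truncated at the forward call. For orientation, a hedged end-to-end sketch of the N-d path it configures: a 3-D max-pool descriptor followed by the forward pass. pool3d_forward and all extents are illustrative; descriptor creation and CUDNN_ENFORCE-style error checks are elided:

#include <cudnn.h>

void pool3d_forward(cudnnHandle_t handle,
                    cudnnPoolingDescriptor_t pooling_desc,
                    cudnnTensorDescriptor_t bottom_desc,  // X, 5-D
                    cudnnTensorDescriptor_t top_desc,     // Y, 5-D
                    const float* X, float* Y) {
  int window[3] = {2, 2, 2};   // kernel extent per spatial axis
  int padding[3] = {0, 0, 0};  // leading pad per spatial axis
  int stride[3] = {2, 2, 2};   // stride per spatial axis
  cudnnSetPoolingNdDescriptor(
      pooling_desc, CUDNN_POOLING_MAX, CUDNN_NOT_PROPAGATE_NAN,
      3, window, padding, stride);
  const float alpha = 1.0f, beta = 0.0f;
  cudnnPoolingForward(
      handle, pooling_desc, &alpha, bottom_desc, X, &beta, top_desc, Y);
}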
@@ -155,40 +212,55 @@ class CuDNNPoolGradientOp : public ConvPoolOpBase<CUDAContext> {
auto& Y = Input(1);
auto& dY = Input(2);
auto* dX = Output(0);
+ // cuDNN pooling supports only 2 and 3 spatial dimensions.
+ CAFFE_ENFORCE(X.ndim() >= 4 && X.ndim() <= 5);
dX->ResizeLike(X);
- int N = 0, C = 0, H = 0, W = 0;
+ int N = 0, C = 0, H = 0, W = 0, D = 0;
+ int H_out = 0, W_out = 0, D_out = 0;
switch (order_) {
case StorageOrder::NHWC:
- N = X.dim32(0); H = X.dim32(1); W = X.dim32(2); C = X.dim32(3);
+ N = X.dim32(0);
+ H = X.dim32(1);
+ W = X.ndim() > 3 ? X.dim32(2) : 1;
+ D = X.ndim() > 4 ? X.dim32(3) : 1;
+ C = X.dim32(X.ndim() - 1);
+ H_out = Y.dim32(1);
+ W_out = Y.ndim() > 3 ? Y.dim32(2) : 1;
+ D_out = Y.ndim() > 4 ? Y.dim32(3) : 1;
break;
case StorageOrder::NCHW:
- N = X.dim32(0); C = X.dim32(1); H = X.dim32(2); W = X.dim32(3);
+ N = X.dim32(0);
+ C = X.dim32(1);
+ H = X.dim32(2);
+ W = X.ndim() > 3 ? X.dim32(3) : 1;
+ D = X.ndim() > 4 ? X.dim32(4) : 1;
+ H_out = Y.dim32(2);
+ W_out = Y.ndim() > 3 ? Y.dim32(3) : 1;
+ D_out = Y.ndim() > 4 ? Y.dim32(4) : 1;
break;
default:
LOG(FATAL) << "Unknown storage order: " << order_;
}
- ConvPoolOpBase<CUDAContext>::ComputePads({H, W});
+ if (kernel_.size() == 1) {
+ ConvPoolOpBase<CUDAContext>::ComputePads({H});
+ } else if (kernel_.size() == 2) {
+ ConvPoolOpBase<CUDAContext>::ComputePads({H, W});
+ } else if (kernel_.size() == 3) {
+ ConvPoolOpBase<CUDAContext>::ComputePads({H, W, D});
+ } else {
+ CAFFE_THROW("Unsupported kernel size: ", kernel_.size());
+ }
if (cudnn_input_dims_ != X.dims()) {
// Dimensions changed; we will need to re-initialize things.
VLOG(1) << "Changing the cudnn descriptor configurations.";
cudnn_input_dims_ = X.dims();
- CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
- bottom_desc_,
- GetCudnnTensorFormat(order_),
- cudnnTypeWrapper<T>::type,
- N,
- C,
- H,
- W));
- CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
- top_desc_,
- GetCudnnTensorFormat(order_),
- cudnnTypeWrapper<T>::type,
- N,
- C,
- order_ == StorageOrder::NCHW ? Y.dim32(2) : Y.dim32(1),
- order_ == StorageOrder::NCHW ? Y.dim32(3) : Y.dim32(2)));
+ setTensorDescriptor<T>(X.ndim(), order_, N, C, H, W, D, bottom_desc_);
+ setTensorDescriptor<T>(
+ Y.ndim(), order_, N, C, H_out, W_out, D_out, top_desc_);
if (pad_t() != pad_l() || pad_l() != pad_r()) {
CAFFE_ENFORCE(
legacy_pad_ == LegacyPadding::CAFFE_LEGACY_POOLING,
@@ -196,16 +268,27 @@ class CuDNNPoolGradientOp : public ConvPoolOpBase<CUDAContext> {
"the only exception of the caffe legacy pooling case where we "
"try to preserve backward compatibility with Caffe.");
}
- CUDNN_ENFORCE(cudnnSetPooling2dDescriptor(
- pooling_desc_,
- mode_,
- CUDNN_NOT_PROPAGATE_NAN,
- kernel_h(),
- kernel_w(),
- pad_t(),
- pad_l(),
- stride_h(),
- stride_w()));
+ if (kernel_.size() == 2) {
+ CUDNN_ENFORCE(cudnnSetPooling2dDescriptor(
+ pooling_desc_,
+ mode_,
+ CUDNN_NOT_PROPAGATE_NAN,
+ kernel_h(),
+ kernel_w(),
+ pad_t(),
+ pad_l(),
+ stride_h(),
+ stride_w()));
+ } else {
+ CUDNN_ENFORCE(cudnnSetPoolingNdDescriptor(
+ pooling_desc_,
+ mode_,
+ CUDNN_NOT_PROPAGATE_NAN,
+ kernel_.size(),
+ kernel_.data(),
+ pads_.data(),
+ stride_.data()));
+ }
}
// Carry out the pooling computation.
CUDNN_ENFORCE(cudnnPoolingBackward(
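This hunk is likewise truncated at the backward call. A hedged sketch of what it feeds into: cuDNN consumes Y, dY, X, and dX through the same descriptors the forward pass configured, so the N-d generalization needs no extra work here (pool_backward and the float-only signature are illustrative):

#include <cudnn.h>

void pool_backward(cudnnHandle_t handle,
                   cudnnPoolingDescriptor_t pooling_desc,
                   cudnnTensorDescriptor_t top_desc,     // describes Y and dY
                   cudnnTensorDescriptor_t bottom_desc,  // describes X and dX
                   const float* Y, const float* dY,
                   const float* X, float* dX) {
  const float alpha = 1.0f, beta = 0.0f;
  cudnnPoolingBackward(
      handle, pooling_desc,
      &alpha, top_desc, Y, top_desc, dY,
      bottom_desc, X, &beta, bottom_desc, dX);
}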

[File 2/2: pooling operator Python test]

@@ -117,9 +117,10 @@ class TestPooling(hu.HypothesisTestCase):
batch_size=st.integers(1, 3),
order=st.sampled_from(["NCHW", "NHWC"]),
method=st.sampled_from(["MaxPool", "AveragePool"]),
+ engine=st.sampled_from(["", "CUDNN"]),
**hu.gcs)
def test_pooling_3d(self, stride, pad, kernel, size, input_channels,
- batch_size, order, method, gc, dc):
+ batch_size, order, method, engine, gc, dc):
assume(pad < kernel)
op = core.CreateOperator(
method,
@@ -129,7 +130,7 @@ class TestPooling(hu.HypothesisTestCase):
kernels=[kernel] * 3,
pads=[pad] * 6,
order=order,
engine="",
engine=engine,
)
X = np.random.rand(
batch_size, size, size, size, input_channels).astype(np.float32)
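For readers following from the C++ side, a hedged sketch of the OperatorDef this test builds, written with the generated caffe2.proto setters (MakePool3d, the blob names, and the literal extents are illustrative, not test code):

#include "caffe2/proto/caffe2.pb.h"

// 3-D max pooling on the CUDNN engine: kernels and strides carry one
// entry per spatial axis, pads carries head and tail entries (2 * 3 = 6).
caffe2::OperatorDef MakePool3d() {
  caffe2::OperatorDef def;
  def.set_type("MaxPool");
  def.add_input("X");
  def.add_output("Y");
  def.set_engine("CUDNN");
  auto* kernels = def.add_arg();
  kernels->set_name("kernels");
  for (int i = 0; i < 3; ++i) kernels->add_ints(2);
  auto* pads = def.add_arg();
  pads->set_name("pads");
  for (int i = 0; i < 6; ++i) pads->add_ints(0);
  auto* strides = def.add_arg();
  strides->set_name("strides");
  for (int i = 0; i < 3; ++i) strides->add_ints(2);
  return def;
}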