generalize order switch ops for 1-3d (#10395)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10395

Order switch ops (NCHW2NHWC and NHWC2NCHW) only supported 2D images.
This diff generalizes them to 1D and 3D images, and also adds a unit test we didn't have.
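
Both ops are just a fixed axis permutation that keeps the batch dimension first and moves the channel dimension between the front and the back of the spatial dimensions. A minimal NumPy sketch of the intended semantics (this mirrors the reference implementations in the new unit test, not the C++/CUDA code itself):

import numpy as np

def nchw2nhwc_ref(X):
    # (N, C, H_1, ..., H_k) -> (N, H_1, ..., H_k, C)
    return X.transpose((0,) + tuple(range(2, X.ndim)) + (1,))

def nhwc2nchw_ref(X):
    # (N, H_1, ..., H_k, C) -> (N, C, H_1, ..., H_k)
    return X.transpose((0, X.ndim - 1) + tuple(range(1, X.ndim - 1)))

X = np.random.rand(2, 3, 4, 5)  # N=2, C=3, H=4, W=5, i.e. k=2
assert nchw2nhwc_ref(X).shape == (2, 4, 5, 3)
assert np.array_equal(nhwc2nchw_ref(nchw2nhwc_ref(X)), X)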

Reviewed By: protonu

Differential Revision: D9261177

fbshipit-source-id: 56e7ec54c9a8fb71781ac1336f3f28cf024b4bda
Authored by Jongsoo Park on 2018-08-15 09:56:23 -07:00; committed by Facebook Github Bot
parent 0f05f5fb07
commit d8ff7ad6f8
3 changed files with 151 additions and 39 deletions


@@ -6,19 +6,35 @@ template <>
bool NHWC2NCHWOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
CAFFE_ENFORCE(X.ndim() == 4);
const int N = X.dim32(0), H = X.dim32(1), W = X.dim32(2), C = X.dim32(3);
Y->Resize(N, C, H, W);
auto ndim = X.ndim();
CAFFE_ENFORCE_GE(ndim, 3);
const int N = X.dim32(0), C = X.dim32(ndim - 1);
vector<TIndex> Y_dims(ndim);
Y_dims[0] = N;
Y_dims[1] = C;
int image_size = 1;
for (auto i = 2; i < ndim; ++i) {
Y_dims[i] = X.dim32(i - 1);
image_size *= Y_dims[i];
}
Y->Resize(Y_dims);
if (X.size() <= 0) {
return true;
}
const float* Xdata = X.data<float>();
float* Ydata = Y->mutable_data<float>();
std::array<int, 2> dims = {image_size, C};
std::array<int, 2> axes = {1, 0};
for (int n = 0; n < N; ++n) {
for (int h = 0; h < H; ++h) {
for (int w = 0; w < W; ++w) {
for (int c = 0; c < C; ++c) {
Ydata[((n * C + c) * H + h) * W + w] = *(Xdata++);
}
}
}
math::Transpose(
2,
dims.data(),
axes.data(),
Xdata + n * image_size * C,
Ydata + n * image_size * C,
&context_);
}
return true;
}
@@ -27,19 +43,35 @@ template <>
bool NCHW2NHWCOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
CAFFE_ENFORCE(X.ndim() == 4);
const int N = X.dim32(0), C = X.dim32(1), H = X.dim32(2), W = X.dim32(3);
Y->Resize(N, H, W, C);
auto ndim = X.ndim();
CAFFE_ENFORCE_GE(X.ndim(), 3);
const int N = X.dim32(0), C = X.dim32(1);
vector<TIndex> Y_dims(ndim);
Y_dims[0] = N;
int image_size = 1;
for (auto i = 1; i < ndim - 1; ++i) {
Y_dims[i] = X.dim32(i + 1);
image_size *= Y_dims[i];
}
Y_dims[ndim - 1] = C;
Y->Resize(Y_dims);
if (X.size() <= 0) {
return true;
}
const float* Xdata = X.data<float>();
float* Ydata = Y->mutable_data<float>();
std::array<int, 2> dims = {C, image_size};
std::array<int, 2> axes = {1, 0};
for (int n = 0; n < N; ++n) {
for (int c = 0; c < C; ++c) {
for (int h = 0; h < H; ++h) {
for (int w = 0; w < W; ++w) {
Ydata[((n * H + h) * W + w) * C + c] = *(Xdata++);
}
}
}
math::Transpose(
2,
dims.data(),
axes.data(),
Xdata + n * image_size * C,
Ydata + n * image_size * C,
&context_);
}
return true;
}
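
Instead of the hard-coded 4D nested loops, both CPU paths now collapse all spatial dimensions into a single image_size and let math::Transpose do a per-sample 2D transpose between (image_size, C) and (C, image_size); that is what makes the same code work for k = 1, 2, and 3. A small NumPy sketch of that equivalence for one NHWC sample (illustration only, not the Caffe2 code):

import numpy as np

# One NHWC sample with k = 3 spatial dimensions: shape (H1, H2, H3, C).
x = np.random.rand(2, 3, 4, 5)
C = x.shape[-1]
image_size = x.size // C

# Transposing the collapsed (image_size, C) view ...
y_2d = x.reshape(image_size, C).T.reshape((C,) + x.shape[:-1])
# ... gives the same result as moving the channel axis to the front directly.
y_ref = np.moveaxis(x, -1, 0)
assert np.array_equal(y_2d, y_ref)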
@@ -53,18 +85,22 @@ OPERATOR_SCHEMA(NHWC2NCHW)
.NumOutputs(1)
.TensorInferenceFunction([](const OperatorDef& /*unused*/ /*def*/,
const vector<TensorShape>& in) {
CAFFE_ENFORCE_EQ(
in[0].dims_size(), 4, "Input for NHWC2NCHW must be 4 dimensional");
CAFFE_ENFORCE_GE(
in[0].dims_size(), 3, "Input for NHWC2NCHW must be >= 3 dimensional");
vector<TensorShape> out(1);
out[0].add_dims(in[0].dims(0));
out[0].add_dims(in[0].dims(3));
out[0].add_dims(in[0].dims(1));
out[0].add_dims(in[0].dims(2));
out[0].add_dims(in[0].dims(in[0].dims_size() - 1));
for (auto i = 1; i < in[0].dims_size() - 1; ++i) {
out[0].add_dims(in[0].dims(i));
}
return out;
})
.SetDoc(R"DOC(
The operator switches the order of data in a tensor from NHWC- sample index N,
height H, width W and channels C, to the NCHW order.
height H, width W and channels C, to the NCHW order (this is for 2D images).
In general, this operator switches the order of data in a tensor from N H_1 ...
H_k C to N C H_1 ... H_k for k-dimensional features, and currently supports
k=1, 2, and 3.
)DOC")
.Input(0, "data", "The input data (Tensor) in the NHWC order.")
.Output(0, "output", "The output tensor (Tensor) in the NCHW order.");
@@ -72,9 +108,24 @@ height H, width W and channels C, to the NCHW order.
OPERATOR_SCHEMA(NCHW2NHWC)
.NumInputs(1)
.NumOutputs(1)
.TensorInferenceFunction([](const OperatorDef& /*unused*/ /*def*/,
const vector<TensorShape>& in) {
CAFFE_ENFORCE_GE(
in[0].dims_size(), 3, "Input for NCHW2NHWC must be >= 3 dimensional");
vector<TensorShape> out(1);
out[0].add_dims(in[0].dims(0));
for (auto i = 2; i < in[0].dims_size(); ++i) {
out[0].add_dims(in[0].dims(i));
}
out[0].add_dims(in[0].dims(1));
return out;
})
.SetDoc(R"DOC(
The operator switches the order of data in a tensor from NCHW- sample index N,
channels C, height H and width W, to the NHWC order.
channels C, height H and width W, to the NHWC order (this is for 2D images).
In general, this operator switches the order of data in a tensor from N C H_1
... H_k to N H_1 ... H_k C for k-dimensional features, and currently supports
k=1, 2, and 3.
)DOC")
.Input(0, "data", "The input data (Tensor) in the NCHW order.")
.Output(0, "output", "The output tensor (Tensor) in the NHWC order.");


@@ -3,8 +3,12 @@
namespace caffe2 {
__global__ void NHWC2NCHWKernel(const int N, const int HW, const int C,
const float* X, float* Y) {
__global__ void NHWC2NCHWKernel(
const int N,
const int HW,
const int C,
const float* X,
float* Y) {
CUDA_1D_KERNEL_LOOP(i, N * HW * C) {
const int c = i % C;
const int hw = i / C % HW;
@@ -13,8 +17,12 @@ __global__ void NHWC2NCHWKernel(const int N, const int HW, const int C,
}
}
__global__ void NCHW2NHWCKernel(const int N, const int C, const int HW,
const float* X, float* Y) {
__global__ void NCHW2NHWCKernel(
const int N,
const int C,
const int HW,
const float* X,
float* Y) {
CUDA_1D_KERNEL_LOOP(i, N * C * HW) {
const int hw = i % HW;
const int c = i / HW % C;
@@ -27,15 +35,26 @@ template <>
bool NHWC2NCHWOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
DCHECK_EQ(X.ndim(), 4);
const int N = X.dim32(0), H = X.dim32(1), W = X.dim32(2), C = X.dim32(3);
Y->Resize(N, C, H, W);
auto ndim = X.ndim();
DCHECK_GE(ndim, 3);
const int N = X.dim32(0), C = X.dim32(ndim - 1);
vector<TIndex> Y_dims(ndim);
Y_dims[0] = N;
Y_dims[1] = C;
size_t image_size = 1;
for (auto i = 2; i < ndim; ++i) {
Y_dims[i] = X.dim32(i - 1);
image_size *= Y_dims[i];
}
Y->Resize(Y_dims);
NHWC2NCHWKernel<<<
CAFFE_GET_BLOCKS(X.size()),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
N, H * W, C, X.data<float>(), Y->template mutable_data<float>());
N, image_size, C, X.data<float>(), Y->template mutable_data<float>());
return true;
}
@@ -43,15 +62,26 @@ template <>
bool NCHW2NHWCOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto* Y = Output(0);
DCHECK_EQ(X.ndim(), 4);
const int N = X.dim32(0), C = X.dim32(1), H = X.dim32(2), W = X.dim32(3);
Y->Resize(N, H, W, C);
auto ndim = X.ndim();
DCHECK_GE(X.ndim(), 3);
const int N = X.dim32(0), C = X.dim32(1);
vector<TIndex> Y_dims(ndim);
Y_dims[0] = N;
size_t image_size = 1;
for (auto i = 1; i < ndim - 1; ++i) {
Y_dims[i] = X.dim32(i + 1);
image_size *= Y_dims[i];
}
Y_dims[ndim - 1] = C;
Y->Resize(Y_dims);
NCHW2NHWCKernel<<<
CAFFE_GET_BLOCKS(X.size()),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
N, C, H * W, X.data<float>(), Y->template mutable_data<float>());
N, C, image_size, X.data<float>(), Y->template mutable_data<float>());
return true;
}
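
The CUDA kernels themselves are unchanged apart from argument reformatting: they already index the data through flat (N, HW, C) and (N, C, HW) views, so passing image_size (the product of all spatial dimensions) where H * W used to go is the only change needed for k = 1 and k = 3. A Python sketch of the index decomposition NHWC2NCHWKernel performs (the write-back index below is an assumption based on the NCHW layout, since the hunk truncates the kernel body):

# Flat-index mapping: input laid out as N x HW x C, output as N x C x HW.
def nhwc2nchw_flat(X_flat, N, HW, C):
    Y_flat = [0.0] * (N * HW * C)
    for i in range(N * HW * C):
        c = i % C
        hw = (i // C) % HW
        n = i // (C * HW)
        Y_flat[(n * C + c) * HW + hw] = X_flat[i]
    return Y_flat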


@@ -0,0 +1,31 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import caffe2.python.hypothesis_test_util as hu
from caffe2.python import core
from hypothesis import given


class OrderSwitchOpsTest(hu.HypothesisTestCase):
    @given(X=hu.tensor(min_dim=3, max_dim=5, min_value=1, max_value=5), **hu.gcs)
    def test_nchw2nhwc(self, X, gc, dc):
        op = core.CreateOperator("NCHW2NHWC", ["X"], ["Y"], device_option=gc)

        def nchw2nhwc_ref(X):
            X_reshaped = X.transpose((0,) + tuple(range(2, X.ndim)) + (1,))
            return (X_reshaped,)

        self.assertReferenceChecks(gc, op, [X], nchw2nhwc_ref)
        self.assertGradientChecks(gc, op, [X], 0, [0])
        self.assertDeviceChecks(dc, op, [X], [0])

    @given(X=hu.tensor(min_dim=3, max_dim=5, min_value=1, max_value=5), **hu.gcs)
    def test_nhwc2nchw(self, X, gc, dc):
        op = core.CreateOperator("NHWC2NCHW", ["X"], ["Y"], device_option=gc)

        def nhwc2nchw_ref(X):
            X_reshaped = X.transpose((0, X.ndim - 1) + tuple(range(1, X.ndim - 1)))
            return (X_reshaped,)

        self.assertReferenceChecks(gc, op, [X], nhwc2nchw_ref)
        self.assertGradientChecks(gc, op, [X], 0, [0])
        self.assertDeviceChecks(dc, op, [X], [0])