diff --git a/caffe2/operators/channel_stats_op.cc b/caffe2/operators/channel_stats_op.cc
index c4ed32cb05f..af736c77f20 100644
--- a/caffe2/operators/channel_stats_op.cc
+++ b/caffe2/operators/channel_stats_op.cc
@@ -1,39 +1,51 @@
 #include "caffe2/operators/channel_stats_op.h"
+
+#include "caffe2/utils/eigen_utils.h"
 
 namespace caffe2 {
 
 template <>
-bool ChannelStatsOp<CPUContext>::RunOnDevice() {
-  const auto& X = Input(INPUT);
-  CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
-  const int N = X.dim32(0);
-  const int C = X.dim32(1);
-  const int H = X.dim32(2);
-  const int W = X.dim() > 3 ? X.dim32(3) : 1;
-  const int D = X.dim() > 4 ? X.dim32(4) : 1;
-
-  const int sampleSize = H * W * D;
-
-  Output(SUM)->Resize(C);
-  Output(SUMSQ)->Resize(C);
-  EigenVectorArrayMap<float> sum(
-      Output(SUM)->template mutable_data<float>(), C);
-  EigenVectorArrayMap<float> sumsq(
-      Output(SUMSQ)->template mutable_data<float>(), C);
-
-  sum.setZero();
-  sumsq.setZero();
-  ConstEigenArrayMap<float> X_arr(X.data<float>(), sampleSize, N * C);
-  auto index = 0;
-  for (int n = 0; n < N; ++n) {
-    for (int c = 0; c < C; ++c) {
-      sum(c) += X_arr.col(index).sum();
-      sumsq(c) += X_arr.col(index).matrix().squaredNorm();
-      index++;
+template <>
+bool ChannelStatsOp<CPUContext>::ComputeChannelStatsNCHW<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    float* sum,
+    float* sumsq) {
+  ConstEigenArrayMap<float> X_arr(X, HxW, N * C);
+  for (int i = 0; i < C; ++i) {
+    sum[i] = X_arr.col(i).sum();
+    sumsq[i] = X_arr.col(i).square().sum();
+  }
+  for (int i = 1; i < N; ++i) {
+    for (int j = 0; j < C; ++j) {
+      const int c = i * C + j;
+      sum[j] += X_arr.col(c).sum();
+      sumsq[j] += X_arr.col(c).square().sum();
     }
   }
+  return true;
+}
 
+template <>
+template <>
+bool ChannelStatsOp<CPUContext>::ComputeChannelStatsNHWC<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    float* sum,
+    float* sumsq) {
+  ConstEigenArrayMap<float> X_arr(X, C, N * HxW);
+  EigenVectorArrayMap<float> sum_arr(sum, C);
+  EigenVectorArrayMap<float> sumsq_arr(sumsq, C);
+  sum_arr = X_arr.col(0);
+  sumsq_arr = X_arr.col(0).square();
+  for (int i = 1; i < N * HxW; ++i) {
+    sum_arr += X_arr.col(i);
+    sumsq_arr += X_arr.col(i).square();
+  }
   return true;
 }
 
@@ -49,7 +61,6 @@ reduced across multiple batches and used to obtain the mean and variance across
 the full set of batches. Using the new mean and variance as input to SpatialBN
 has the effect of changing the batch size over which SpatialBN is applied.
 )DOC")
-
     .Input(0, "X", "The input 4-dimensional tensor of shape NCHW")
     .Output(
         0,
@@ -61,5 +72,7 @@ has the effect of changing the batch size over which SpatialBN is applied.
"sumsq", "The output 1-dimensional tensor of size C containing the sum of " "elements squared per channel."); + SHOULD_NOT_DO_GRADIENT(ChannelStats); + } // namespace caffe2 diff --git a/caffe2/operators/channel_stats_op.cu b/caffe2/operators/channel_stats_op.cu index 5243005b8d0..ae7f7d275b5 100644 --- a/caffe2/operators/channel_stats_op.cu +++ b/caffe2/operators/channel_stats_op.cu @@ -1,192 +1,117 @@ -#include "caffe2/core/context_gpu.h" #include "caffe2/operators/channel_stats_op.h" +#include "caffe2/core/context_gpu.h" +#include "caffe2/utils/math/reduce.cuh" + namespace caffe2 { namespace { -// based on "Optimizing Parallel Reduction in CUDA" by Mark Harris - -// note - volatile keyword is needed to allow doing a warp reduction without -// synchronization on recent architectures -template -__device__ void warpReduce(volatile float* sdata, unsigned int tid) { - // note - the if statements are "free" as they are resolved at compile time - if (blockSize >= 64) - sdata[tid] += sdata[tid + 32]; - if (blockSize >= 32) - sdata[tid] += sdata[tid + 16]; - if (blockSize >= 16) - sdata[tid] += sdata[tid + 8]; - if (blockSize >= 8) - sdata[tid] += sdata[tid + 4]; - if (blockSize >= 4) - sdata[tid] += sdata[tid + 2]; - if (blockSize >= 2) - sdata[tid] += sdata[tid + 1]; -} - -template -__global__ void ChannelStatsBlockKernel( - int N, - int C, - int valsPerChannel, - const float* inputData, - float* sums, - float* sumsq) { - __shared__ float sumData[blockSize]; - __shared__ float sumSqData[blockSize]; - - auto tid = threadIdx.x; - auto numBlocksPerChannel = (valsPerChannel + blockSize - 1) / blockSize; - auto localBlockIndex = blockIdx.x % numBlocksPerChannel; - auto inputIndex = (blockIdx.x / numBlocksPerChannel) * valsPerChannel + - localBlockIndex * blockSize + tid; - - sumData[tid] = 0; - sumSqData[tid] = 0; - - if (localBlockIndex * blockSize + tid < valsPerChannel) { - sumData[tid] += inputData[inputIndex]; - sumSqData[tid] += inputData[inputIndex] * inputData[inputIndex]; - } - - __syncthreads(); - if (blockSize >= 512) { - if (tid < 256) { - sumData[tid] += sumData[tid + 256]; - sumSqData[tid] += sumSqData[tid + 256]; +template +__global__ void ChannelStatsNCHWCUDAKernel( + const int N, + const int C, + const int HxW, + const T* X, + T* sum, + T* sumsq) { + __shared__ + typename BlockReduce2D::TempStorage m_storage; + __shared__ + typename BlockReduce2D::TempStorage v_storage; + const int c = blockIdx.x; + T m_val = 0; + T v_val = 0; + for (int n = threadIdx.x; n < N; n += blockDim.x) { + for (int hw = threadIdx.y; hw < HxW; hw += blockDim.y) { + const int index = (n * C + c) * HxW + hw; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + m_val += __ldg(X + index); + v_val += __ldg(X + index) * __ldg(X + index); +#else + m_val += X[index]; + v_val += X[index] * X[index]; +#endif } - __syncthreads(); } - if (blockSize >= 256) { - if (tid < 128) { - sumData[tid] += sumData[tid + 128]; - sumSqData[tid] += sumSqData[tid + 128]; - } - __syncthreads(); - } - if (blockSize >= 128) { - if (tid < 64) { - sumData[tid] += sumData[tid + 64]; - sumSqData[tid] += sumSqData[tid + 64]; - } - __syncthreads(); - } - - if (tid < 32) { - warpReduce(sumData, tid); - warpReduce(sumSqData, tid); - } - - // output block data sorted by C to simplify second reduction - if (tid == 0) { - auto n = blockIdx.x / numBlocksPerChannel / C; - auto c = (blockIdx.x / numBlocksPerChannel) % C; - auto outputIndex = (c * N + n) * numBlocksPerChannel + localBlockIndex; - sums[outputIndex] = sumData[0]; - 
-    sumsq[outputIndex] = sumSqData[0];
+  m_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(m_storage).Sum(m_val);
+  v_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(v_storage).Sum(v_val);
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+    sum[c] = m_val;
+    sumsq[c] = v_val;
   }
 }
 
-template <unsigned int blockSize>
-__global__ void ChannelStatsFinalSumsKernel(
-    int N,
-    int C,
-    int numSumsPerChannel,
-    const float* sumsScratch,
-    const float* sumsqScratch,
-    float* channelSums,
-    float* channelSumsq) {
-  __shared__ float sumData[blockSize];
-  __shared__ float sumSqData[blockSize];
-
-  auto tid = threadIdx.x;
-  auto inputIndex = blockIdx.x * N * numSumsPerChannel + tid;
-  sumData[tid] = 0;
-  sumSqData[tid] = 0;
-  for (auto i = inputIndex; i < (blockIdx.x + 1) * N * numSumsPerChannel;
-       i += blockSize) {
-    sumData[tid] += sumsScratch[i];
-    sumSqData[tid] += sumsqScratch[i];
+template <typename T>
+__global__ void ChannelStatsNHWCCUDAKernel(
+    const int N,
+    const int C,
+    const int HxW,
+    const T* X,
+    T* sum,
+    T* sumsq) {
+  __shared__ typename BlockReduce<T>::TempStorage m_storage;
+  __shared__ typename BlockReduce<T>::TempStorage v_storage;
+  const int inner_size = N * HxW;
+  const int c = blockIdx.x;
+  T m_val = 0;
+  T v_val = 0;
+  for (int i = threadIdx.x; i < inner_size; i += blockDim.x) {
+    const int index = i * C + c;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    m_val += __ldg(X + index);
+    v_val += __ldg(X + index) * __ldg(X + index);
+#else
+    m_val += X[index];
+    v_val += X[index] * X[index];
+#endif
   }
-  __syncthreads();
-  if (blockSize >= 512) {
-    if (tid < 256) {
-      sumData[tid] += sumData[tid + 256];
-      sumSqData[tid] += sumSqData[tid + 256];
-    }
-    __syncthreads();
-  }
-  if (blockSize >= 256) {
-    if (tid < 128) {
-      sumData[tid] += sumData[tid + 128];
-      sumSqData[tid] += sumSqData[tid + 128];
-    }
-    __syncthreads();
-  }
-  if (blockSize >= 128) {
-    if (tid < 64) {
-      sumData[tid] += sumData[tid + 64];
-      sumSqData[tid] += sumSqData[tid + 64];
-    }
-    __syncthreads();
-  }
-  if (tid < 32) {
-    warpReduce<blockSize>(sumData, tid);
-    warpReduce<blockSize>(sumSqData, tid);
-  }
-
-  if (tid == 0) {
-    channelSums[blockIdx.x] = sumData[0];
-    channelSumsq[blockIdx.x] = sumSqData[0];
+  m_val = BlockReduce<T>(m_storage).Sum(m_val);
+  v_val = BlockReduce<T>(v_storage).Sum(v_val);
+  if (threadIdx.x == 0) {
+    sum[c] = m_val;
+    sumsq[c] = v_val;
   }
 }
+
 } // namespace
 
 template <>
-bool ChannelStatsOp<CUDAContext>::RunOnDevice() {
-  const auto& X = Input(INPUT);
-  CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
-  const int N = X.dim32(0);
-  const int C = X.dim32(1);
-  const int H = X.dim32(2);
-  const int W = X.dim() > 3 ? X.dim32(3) : 1;
-  const int D = X.dim() > 4 ? X.dim32(4) : 1;
+template <>
+bool ChannelStatsOp<CUDAContext>::ComputeChannelStatsNCHW<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    float* sum,
+    float* sumsq) {
+  DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1(
+      HxW,
+      ChannelStatsNCHWCUDAKernel,
+      float,
+      C,
+      context_.cuda_stream(),
+      N,
+      C,
+      HxW,
+      X,
+      sum,
+      sumsq);
+  return true;
+}
 
-  const auto X_arr = X.data<float>();
-  const auto valsPerChannel = H * W * D;
-
-  const auto numBlocksPerChannel = CAFFE_GET_BLOCKS(valsPerChannel);
-  const auto numBlocksTotal = numBlocksPerChannel * N * C;
-
-  ReinitializeTensor(
-      &sumScratch_, {numBlocksTotal}, at::dtype<float>().device(CUDA));
-  ReinitializeTensor(
-      &sumsqScratch_, {numBlocksTotal}, at::dtype<float>().device(CUDA));
-
-  auto sum = Output(SUM, {C}, at::dtype<float>());
-  auto sumsq = Output(SUMSQ, {C}, at::dtype<float>());
-
-  ChannelStatsBlockKernel<CAFFE_CUDA_NUM_THREADS>
-      <<<numBlocksTotal, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          N,
-          C,
-          valsPerChannel,
-          X_arr,
-          sumScratch_.mutable_data<float>(),
-          sumsqScratch_.mutable_data<float>());
-
-  ChannelStatsFinalSumsKernel<CAFFE_CUDA_NUM_THREADS>
+template <>
+template <>
+bool ChannelStatsOp<CUDAContext>::ComputeChannelStatsNHWC<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    float* sum,
+    float* sumsq) {
+  ChannelStatsNHWCCUDAKernel<float>
       <<<C, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          N,
-          C,
-          numBlocksPerChannel,
-          sumScratch_.data<float>(),
-          sumsqScratch_.data<float>(),
-          sum->template mutable_data<float>(),
-          sumsq->template mutable_data<float>());
-
+          N, C, HxW, X, sum, sumsq);
   return true;
 }
diff --git a/caffe2/operators/channel_stats_op.h b/caffe2/operators/channel_stats_op.h
index 9bdd3d36db4..17ff43b77ed 100644
--- a/caffe2/operators/channel_stats_op.h
+++ b/caffe2/operators/channel_stats_op.h
@@ -1,5 +1,7 @@
-#ifndef CAFFE2_OPERATORS_CHANNEL_STATS_OP_H
-#define CAFFE2_OPERATORS_CHANNEL_STATS_OP_H
+#ifndef CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
+#define CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
+
+#include <string>
 
 #include "caffe2/core/context.h"
 #include "caffe2/core/operator.h"
@@ -8,26 +10,51 @@
 namespace caffe2 {
 
 template <class Context>
-class ChannelStatsOp : public Operator<Context> {
+class ChannelStatsOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
+
   template <class... Args>
   explicit ChannelStatsOp(Args&&... args)
-      : Operator<Context>(std::forward<Args>(args)...) {}
-  ~ChannelStatsOp() {}
-
-  bool RunOnDevice() override {
-    return true;
+      : Operator<Context>(std::forward<Args>(args)...),
+        order_(StringToStorageOrder(
+            this->template GetSingleArgument<std::string>("order", "NCHW"))) {
+    CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
   }
 
- protected:
-  INPUT_TAGS(INPUT);
-  OUTPUT_TAGS(SUM, SUMSQ);
+  bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
+  }
 
-  Tensor sumScratch_;
-  Tensor sumsqScratch_;
+  template <typename T>
+  bool DoRunWithType() {
+    const auto& X = Input(0);
+    const int ndim = X.dim();
+    const int N = X.dim32(0);
+    const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
+    const int HxW = X.numel() / (N * C);
+    auto* sum = Output(0, {C}, at::dtype<T>());
+    auto* sumsq = Output(1, {C}, at::dtype<T>());
+    const T* X_data = X.template data<T>();
+    T* sum_data = sum->template mutable_data<T>();
+    T* sumsq_data = sumsq->template mutable_data<T>();
+    return order_ == StorageOrder::NCHW
+        ? ComputeChannelStatsNCHW(N, C, HxW, X_data, sum_data, sumsq_data)
+        : ComputeChannelStatsNHWC(N, C, HxW, X_data, sum_data, sumsq_data);
+  }
+
+ private:
+  template <typename T>
+  bool
+  ComputeChannelStatsNCHW(int N, int C, int HxW, const T* X, T* sum, T* sumsq);
+
+  template <typename T>
+  bool
+  ComputeChannelStatsNHWC(int N, int C, int HxW, const T* X, T* sum, T* sumsq);
+
+  const StorageOrder order_;
 };
 
 } // namespace caffe2
 
-#endif // CAFFE2_OPERATORS_CHANNEL_STATS_OP_H
+#endif // CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
diff --git a/caffe2/python/operator_test/channel_stats_op_test.py b/caffe2/python/operator_test/channel_stats_op_test.py
index f1daddee772..d793b5f8e43 100644
--- a/caffe2/python/operator_test/channel_stats_op_test.py
+++ b/caffe2/python/operator_test/channel_stats_op_test.py
@@ -1,42 +1,85 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-from __future__ import unicode_literals
 
 from caffe2.python import core
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.serialized_test.serialized_test_util as serial
+
 from hypothesis import assume, given
 import hypothesis.strategies as st
 import numpy as np
+
 import unittest
 
 
-class TestChannelStats(serial.SerializedTestCase):
-    @serial.given(
-        size=st.integers(7, 10),
-        inputChannels=st.integers(1, 10),
-        batchSize=st.integers(1, 3),
-        **hu.gcs
-    )
-    def testChannelStats(self, size, inputChannels, batchSize, gc, dc):
+class TestChannelStatsOp(serial.SerializedTestCase):
+    def channel_stats_nchw_ref(self, X):
+        dims = X.shape
+        N = dims[0]
+        C = dims[1]
+        X = X.reshape(N, C, -1)
+        sum1 = np.sum(X, axis=(0, 2), keepdims=False)
+        sum2 = np.sum(X**2, axis=(0, 2), keepdims=False)
+        return (sum1, sum2)
+
+    def channel_stats_nhwc_ref(self, X):
+        dims = X.shape
+        N = dims[0]
+        C = dims[-1]
+        X = X.reshape(N, -1, C)
+        sum1 = np.sum(X, axis=(0, 1), keepdims=False)
+        sum2 = np.sum(X**2, axis=(0, 1), keepdims=False)
+        return (sum1, sum2)
+
+    @serial.given(
+        N=st.integers(1, 5), C=st.integers(1, 10), H=st.integers(1, 12),
+        W=st.integers(1, 12), order=st.sampled_from(["NCHW", "NHWC"]), **hu.gcs)
+    def test_channel_stats_2d(self, N, C, H, W, order, gc, dc):
         op = core.CreateOperator(
             "ChannelStats",
             ["X"],
             ["sum", "sumsq"],
+            order=order,
         )
 
-        def referenceChannelStatsTest(X):
-            sums = np.sum(X, axis=(0, 2, 3), keepdims=False)
-            sumsq = np.zeros(inputChannels)
-            sumsq = np.sum(X**2, axis=(0, 2, 3), keepdims=False)
-            return sums, sumsq
+        def ref_op(X):
+            if order == "NCHW":
+                return self.channel_stats_nchw_ref(X)
+            else:
+                return self.channel_stats_nhwc_ref(X)
 
-        X = np.random.rand(batchSize, inputChannels, size, size)\
-            .astype(np.float32) - 0.5
-        self.assertReferenceChecks(gc, op, [X], referenceChannelStatsTest)
+        X = np.random.randn(N, C, H, W).astype(np.float32)
+        if order == "NHWC":
+            X = np.transpose(X, [0, 2, 3, 1])
+        self.assertReferenceChecks(gc, op, [X], reference=ref_op)
+        self.assertDeviceChecks(dc, op, [X], [0, 1])
+
+    @serial.given(
+        N=st.integers(1, 5), C=st.integers(1, 10), D=st.integers(1, 6),
+        H=st.integers(1, 6), W=st.integers(1, 6),
+        order=st.sampled_from(["NCHW", "NHWC"]), **hu.gcs)
+    def test_channel_stats_3d(self, N, C, D, H, W, order, gc, dc):
+        op = core.CreateOperator(
+            "ChannelStats",
+            ["X"],
+            ["sum", "sumsq"],
+            order=order,
+        )
+
+        def ref_op(X):
+            if order == "NCHW":
+                return self.channel_stats_nchw_ref(X)
+            else:
+                return self.channel_stats_nhwc_ref(X)
+
+        X = np.random.randn(N, C, D, H, W).astype(np.float32)
+        if order == "NHWC":
+            X = np.transpose(X, [0, 2, 3, 4, 1])
+
+        self.assertReferenceChecks(gc, op, [X], reference=ref_op)
+        self.assertDeviceChecks(dc, op, [X], [0, 1])
 
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/caffe2/python/operator_test/group_norm_op_test.py b/caffe2/python/operator_test/group_norm_op_test.py
index febf05136e4..5507cd7d55a 100644
--- a/caffe2/python/operator_test/group_norm_op_test.py
+++ b/caffe2/python/operator_test/group_norm_op_test.py
@@ -10,6 +10,8 @@ from hypothesis import given
 import hypothesis.strategies as st
 import numpy as np
 
+import unittest
+
 
 class TestGroupNormOp(serial.SerializedTestCase):
     def group_norm_nchw_ref(self, X, gamma, beta, group, epsilon):
@@ -144,3 +146,7 @@ class TestGroupNormOp(serial.SerializedTestCase):
         inputs = [X, gamma, beta]
         for i in range(len(inputs)):
             self.assertGradientChecks(gc, op, inputs, i, [0])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/caffe2/python/serialized_test/SerializedTestCoverage.md b/caffe2/python/serialized_test/SerializedTestCoverage.md
index ae31d43fc6f..4db6c370471 100644
--- a/caffe2/python/serialized_test/SerializedTestCoverage.md
+++ b/caffe2/python/serialized_test/SerializedTestCoverage.md
@@ -1,11 +1,11 @@
 # Serialized Test Coverage Report
 This is an automatically generated file. Please see `caffe2/python/serialized_test/README.md` for details. In the case of merge conflicts, please rebase and regenerate.
 ## Summary
-Serialized tests have covered 217/684 (31.7%) operators
+Serialized tests have covered 219/688 (31.8%) operators
 
 ## Not covered operators
 <details>
-<summary>There are 467 not covered operators</summary>
+<summary>There are 469 not covered operators</summary>
 
 * APMeter
 * ATen
@@ -17,6 +17,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 * Adam
 * Add
 * AddGradient
+* AdjustBatch
 * Alias
 * Allgather
 * Allreduce
@@ -96,6 +97,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 * CubeGradient
 * DBExists
 * DataCouple
+* DenseVectorToIdList
 * DepthConcat
 * DepthSplit
 * DequeueBlobs
@@ -478,7 +480,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 
 ## Covered operators
 <details>
-<summary>There are 217 covered operators</summary>
+<summary>There are 219 covered operators</summary>
 
 * Acos
 * AcosGradient
@@ -543,6 +545,8 @@ Serialized tests have covered 217/684 (31.7%) operators
 * ElementwiseLinearGradient
 * Elu
 * EluGradient
+* Erf
+* ErfGradient
 * Expand
 * ExpandGradient
 * FC
@@ -702,7 +706,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 ## Excluded from coverage statistics
 ### Schemaless operators
 <details>
-<summary>There are 21 schemaless operators</summary>
+<summary>There are 22 schemaless operators</summary>
 
 * C10Add_DontUseThisOpYet
 * C10AveragedLoss_DontUseThisOpYet
@@ -718,6 +722,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 * C10GivenTensorFill_DontUseThisOpYet
 * C10GivenTensorInt64Fill_DontUseThisOpYet
 * C10GivenTensorIntFill_DontUseThisOpYet
+* C10LayerNorm_DontUseThisOpYet
 * C10Mul_DontUseThisOpYet
 * C10Relu_DontUseThisOpYet
 * C10SigmoidCrossEntropyWithLogits_DontUseThisOpYet
diff --git a/caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_2d.zip b/caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_2d.zip
new file mode 100644
index 00000000000..9e1936974f1
Binary files /dev/null and b/caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_2d.zip differ
diff --git a/caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_3d.zip b/caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_3d.zip
new file mode 100644
index 00000000000..a8a237af747
Binary files /dev/null and b/caffe2/python/serialized_test/data/operator_test/channel_stats_op_test.test_channel_stats_3d.zip differ
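Reviewer note (not part of the patch): a minimal usage sketch of the operator after this change, based on the schema documented in channel_stats_op.cc (sum/sumsq per channel, later combined into mean/variance for SpatialBN) and on what channel_stats_op_test.py exercises. The blob names, the example input shape, and the use of caffe2.python.workspace are illustrative assumptions, not code from this change.

# Illustrative only: turn ChannelStats outputs into per-channel mean/variance.
import numpy as np
from caffe2.python import core, workspace

X = np.random.randn(2, 3, 4, 4).astype(np.float32)  # assumed NCHW example input
workspace.FeedBlob("X", X)
workspace.RunOperatorOnce(core.CreateOperator(
    "ChannelStats", ["X"], ["sum", "sumsq"], order="NCHW"))

s = workspace.FetchBlob("sum")     # shape (C,): per-channel sum of elements
ss = workspace.FetchBlob("sumsq")  # shape (C,): per-channel sum of squares
count = X.shape[0] * X.shape[2] * X.shape[3]  # elements per channel: N * H * W
mean = s / count
var = ss / count - mean ** 2       # E[x^2] - E[x]^2, the statistics SpatialBN consumes

With order="NHWC" the same call works on a channels-last tensor, which is what the new test_channel_stats_2d and test_channel_stats_3d cases cover for both layouts.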