Optimize channel_stats_op (#16243)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16243

Optimize channel_stats_op and add NHWC impl

Reviewed By: takatosp1

Differential Revision: D13775515

fbshipit-source-id: decb889e646f5316d4afefdf9f9b6bc6343613cd
This commit is contained in:
parent
99f1465c35
commit
54b33503ec
caffe2/operators/channel_stats_op.cc
@@ -1,39 +1,51 @@
#include "caffe2/operators/channel_stats_op.h"

#include "caffe2/utils/eigen_utils.h"

namespace caffe2 {

template <>
bool ChannelStatsOp<CPUContext>::RunOnDevice() {
  const auto& X = Input(INPUT);
  CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
  const int N = X.dim32(0);
  const int C = X.dim32(1);
  const int H = X.dim32(2);
  const int W = X.dim() > 3 ? X.dim32(3) : 1;
  const int D = X.dim() > 4 ? X.dim32(4) : 1;

  const int sampleSize = H * W * D;

  Output(SUM)->Resize(C);
  Output(SUMSQ)->Resize(C);
  EigenVectorArrayMap<float> sum(
      Output(SUM)->template mutable_data<float>(), C);
  EigenVectorArrayMap<float> sumsq(
      Output(SUMSQ)->template mutable_data<float>(), C);

  sum.setZero();
  sumsq.setZero();
  ConstEigenArrayMap<float> X_arr(X.data<float>(), sampleSize, N * C);
  auto index = 0;
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      sum(c) += X_arr.col(index).sum();
      sumsq(c) += X_arr.col(index).matrix().squaredNorm();
      index++;
    }
  }
  return true;
}

template <>
bool ChannelStatsOp<CPUContext>::ComputeChannelStatsNCHW<float>(
    const int N,
    const int C,
    const int HxW,
    const float* X,
    float* sum,
    float* sumsq) {
  ConstEigenArrayMap<float> X_arr(X, HxW, N * C);
  for (int i = 0; i < C; ++i) {
    sum[i] = X_arr.col(i).sum();
    sumsq[i] = X_arr.col(i).square().sum();
  }
  for (int i = 1; i < N; ++i) {
    for (int j = 0; j < C; ++j) {
      const int c = i * C + j;
      sum[j] += X_arr.col(c).sum();
      sumsq[j] += X_arr.col(c).square().sum();
    }
  }
  return true;
}

template <>
template <>
bool ChannelStatsOp<CPUContext>::ComputeChannelStatsNHWC<float>(
    const int N,
    const int C,
    const int HxW,
    const float* X,
    float* sum,
    float* sumsq) {
  ConstEigenArrayMap<float> X_arr(X, C, N * HxW);
  EigenVectorArrayMap<float> sum_arr(sum, C);
  EigenVectorArrayMap<float> sumsq_arr(sumsq, C);
  sum_arr = X_arr.col(0);
  sumsq_arr = X_arr.col(0).square();
  for (int i = 1; i < N * HxW; ++i) {
    sum_arr += X_arr.col(i);
    sumsq_arr += X_arr.col(i).square();
  }
  return true;
}
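The Eigen maps above view the flat NCHW buffer as an (HxW, N*C) column-major array, so each column is the HxW slice of one (sample, channel) pair and per-channel stats come from summing the right columns. A small NumPy sketch of that layout, illustrative only and not part of the commit:

import numpy as np

N, C, HxW = 2, 3, 4
X = np.random.randn(N, C, HxW).astype(np.float32)
# Column k of the (HxW, N*C) view holds the HxW values of sample n = k // C,
# channel c = k % C, which is exactly what ComputeChannelStatsNCHW sums.
cols = X.reshape(N * C, HxW).T
n, c = 1, 2
assert np.allclose(cols[:, n * C + c], X[n, c])
per_channel_sum = cols.reshape(HxW, N, C).sum(axis=(0, 1))
assert np.allclose(per_channel_sum, X.sum(axis=(0, 2)))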
@@ -49,7 +61,6 @@ reduced across multiple batches and used to obtain the mean and variance across
the full set of batches. Using the new mean and variance as input to SpatialBN
has the effect of changing the batch size over which SpatialBN is applied.
)DOC")
    .Input(0, "X", "The input 4-dimensional tensor of shape NCHW")
    .Output(
        0,
@@ -61,5 +72,7 @@ has the effect of changing the batch size over which SpatialBN is applied.
        "sumsq",
        "The output 1-dimensional tensor of size C containing the sum of "
        "elements squared per channel.");

SHOULD_NOT_DO_GRADIENT(ChannelStats);

} // namespace caffe2
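The doc string above says the per-channel sum and sumsq are later combined into a mean and variance across batches. A minimal NumPy sketch of that conversion, illustrative only (the helper name and the count argument are not from the commit):

import numpy as np

def channel_mean_var(sum_, sumsq, count):
    # count = number of elements contributing to each channel (N * H * W here)
    mean = sum_ / count
    var = sumsq / count - mean * mean  # E[x^2] - E[x]^2, biased variance
    return mean, var

X = np.random.randn(2, 3, 4, 4).astype(np.float32)  # hypothetical NCHW batch
s = X.sum(axis=(0, 2, 3))
ss = (X ** 2).sum(axis=(0, 2, 3))
mean, var = channel_mean_var(s, ss, X.shape[0] * X.shape[2] * X.shape[3])
assert np.allclose(mean, X.mean(axis=(0, 2, 3)), atol=1e-5)
assert np.allclose(var, X.var(axis=(0, 2, 3)), atol=1e-5)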
caffe2/operators/channel_stats_op.cu
@@ -1,192 +1,117 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/channel_stats_op.h"

#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math/reduce.cuh"

namespace caffe2 {

namespace {

// based on "Optimizing Parallel Reduction in CUDA" by Mark Harris

// note - volatile keyword is needed to allow doing a warp reduction without
// synchronization on recent architectures
template <unsigned int blockSize>
__device__ void warpReduce(volatile float* sdata, unsigned int tid) {
  // note - the if statements are "free" as they are resolved at compile time
  if (blockSize >= 64)
    sdata[tid] += sdata[tid + 32];
  if (blockSize >= 32)
    sdata[tid] += sdata[tid + 16];
  if (blockSize >= 16)
    sdata[tid] += sdata[tid + 8];
  if (blockSize >= 8)
    sdata[tid] += sdata[tid + 4];
  if (blockSize >= 4)
    sdata[tid] += sdata[tid + 2];
  if (blockSize >= 2)
    sdata[tid] += sdata[tid + 1];
}

template <unsigned int blockSize>
__global__ void ChannelStatsBlockKernel(
    int N,
    int C,
    int valsPerChannel,
    const float* inputData,
    float* sums,
    float* sumsq) {
  __shared__ float sumData[blockSize];
  __shared__ float sumSqData[blockSize];

  auto tid = threadIdx.x;
  auto numBlocksPerChannel = (valsPerChannel + blockSize - 1) / blockSize;
  auto localBlockIndex = blockIdx.x % numBlocksPerChannel;
  auto inputIndex = (blockIdx.x / numBlocksPerChannel) * valsPerChannel +
      localBlockIndex * blockSize + tid;

  sumData[tid] = 0;
  sumSqData[tid] = 0;

  if (localBlockIndex * blockSize + tid < valsPerChannel) {
    sumData[tid] += inputData[inputIndex];
    sumSqData[tid] += inputData[inputIndex] * inputData[inputIndex];
  }

  __syncthreads();
  if (blockSize >= 512) {
    if (tid < 256) {
      sumData[tid] += sumData[tid + 256];
      sumSqData[tid] += sumSqData[tid + 256];

template <typename T, int kBlockDimX, int kBlockDimY>
__global__ void ChannelStatsNCHWCUDAKernel(
    const int N,
    const int C,
    const int HxW,
    const T* X,
    T* sum,
    T* sumsq) {
  __shared__
      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage m_storage;
  __shared__
      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage v_storage;
  const int c = blockIdx.x;
  T m_val = 0;
  T v_val = 0;
  for (int n = threadIdx.x; n < N; n += blockDim.x) {
    for (int hw = threadIdx.y; hw < HxW; hw += blockDim.y) {
      const int index = (n * C + c) * HxW + hw;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
      m_val += __ldg(X + index);
      v_val += __ldg(X + index) * __ldg(X + index);
#else
      m_val += X[index];
      v_val += X[index] * X[index];
#endif
    }
    __syncthreads();
  }
  if (blockSize >= 256) {
    if (tid < 128) {
      sumData[tid] += sumData[tid + 128];
      sumSqData[tid] += sumSqData[tid + 128];
    }
    __syncthreads();
  }
  if (blockSize >= 128) {
    if (tid < 64) {
      sumData[tid] += sumData[tid + 64];
      sumSqData[tid] += sumSqData[tid + 64];
    }
    __syncthreads();
  }

  if (tid < 32) {
    warpReduce<blockSize>(sumData, tid);
    warpReduce<blockSize>(sumSqData, tid);
  }

  // output block data sorted by C to simplify second reduction
  if (tid == 0) {
    auto n = blockIdx.x / numBlocksPerChannel / C;
    auto c = (blockIdx.x / numBlocksPerChannel) % C;
    auto outputIndex = (c * N + n) * numBlocksPerChannel + localBlockIndex;
    sums[outputIndex] = sumData[0];
    sumsq[outputIndex] = sumSqData[0];
  m_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(m_storage).Sum(m_val);
  v_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(v_storage).Sum(v_val);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    sum[c] = m_val;
    sumsq[c] = v_val;
  }
}

template <unsigned int blockSize>
__global__ void ChannelStatsFinalSumsKernel(
    int N,
    int C,
    int numSumsPerChannel,
    const float* sumsScratch,
    const float* sumsqScratch,
    float* channelSums,
    float* channelSumsq) {
  __shared__ float sumData[blockSize];
  __shared__ float sumSqData[blockSize];

  auto tid = threadIdx.x;
  auto inputIndex = blockIdx.x * N * numSumsPerChannel + tid;
  sumData[tid] = 0;
  sumSqData[tid] = 0;
  for (auto i = inputIndex; i < (blockIdx.x + 1) * N * numSumsPerChannel;
       i += blockSize) {
    sumData[tid] += sumsScratch[i];
    sumSqData[tid] += sumsqScratch[i];

template <typename T>
__global__ void ChannelStatsNHWCCUDAKernel(
    const int N,
    const int C,
    const int HxW,
    const T* X,
    T* sum,
    T* sumsq) {
  __shared__ typename BlockReduce<T>::TempStorage m_storage;
  __shared__ typename BlockReduce<T>::TempStorage v_storage;
  const int inner_size = N * HxW;
  const int c = blockIdx.x;
  T m_val = 0;
  T v_val = 0;
  for (int i = threadIdx.x; i < inner_size; i += blockDim.x) {
    const int index = i * C + c;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
    m_val += __ldg(X + index);
    v_val += __ldg(X + index) * __ldg(X + index);
#else
    m_val += X[index];
    v_val += X[index] * X[index];
#endif
  }
  __syncthreads();
  if (blockSize >= 512) {
    if (tid < 256) {
      sumData[tid] += sumData[tid + 256];
      sumSqData[tid] += sumSqData[tid + 256];
    }
    __syncthreads();
  }
  if (blockSize >= 256) {
    if (tid < 128) {
      sumData[tid] += sumData[tid + 128];
      sumSqData[tid] += sumSqData[tid + 128];
    }
    __syncthreads();
  }
  if (blockSize >= 128) {
    if (tid < 64) {
      sumData[tid] += sumData[tid + 64];
      sumSqData[tid] += sumSqData[tid + 64];
    }
    __syncthreads();
  }
  if (tid < 32) {
    warpReduce<blockSize>(sumData, tid);
    warpReduce<blockSize>(sumSqData, tid);
  }

  if (tid == 0) {
    channelSums[blockIdx.x] = sumData[0];
    channelSumsq[blockIdx.x] = sumSqData[0];
  m_val = BlockReduce<T>(m_storage).Sum(m_val);
  v_val = BlockReduce<T>(v_storage).Sum(v_val);
  if (threadIdx.x == 0) {
    sum[c] = m_val;
    sumsq[c] = v_val;
  }
}

} // namespace

template <>
bool ChannelStatsOp<CUDAContext>::RunOnDevice() {
  const auto& X = Input(INPUT);
  CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
  const int N = X.dim32(0);
  const int C = X.dim32(1);
  const int H = X.dim32(2);
  const int W = X.dim() > 3 ? X.dim32(3) : 1;
  const int D = X.dim() > 4 ? X.dim32(4) : 1;

template <>
bool ChannelStatsOp<CUDAContext>::ComputeChannelStatsNCHW<float>(
    const int N,
    const int C,
    const int HxW,
    const float* X,
    float* sum,
    float* sumsq) {
  DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1(
      HxW,
      ChannelStatsNCHWCUDAKernel,
      float,
      C,
      context_.cuda_stream(),
      N,
      C,
      HxW,
      X,
      sum,
      sumsq);
  return true;
}

  const auto X_arr = X.data<float>();
  const auto valsPerChannel = H * W * D;

  const auto numBlocksPerChannel = CAFFE_GET_BLOCKS(valsPerChannel);
  const auto numBlocksTotal = numBlocksPerChannel * N * C;

  ReinitializeTensor(
      &sumScratch_, {numBlocksTotal}, at::dtype<float>().device(CUDA));
  ReinitializeTensor(
      &sumsqScratch_, {numBlocksTotal}, at::dtype<float>().device(CUDA));

  auto sum = Output(SUM, {C}, at::dtype<float>());
  auto sumsq = Output(SUMSQ, {C}, at::dtype<float>());

  ChannelStatsBlockKernel<CAFFE_CUDA_NUM_THREADS>
      <<<numBlocksTotal, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
          N,
          C,
          valsPerChannel,
          X_arr,
          sumScratch_.mutable_data<float>(),
          sumsqScratch_.mutable_data<float>());

  ChannelStatsFinalSumsKernel<CAFFE_CUDA_NUM_THREADS>

template <>
template <>
bool ChannelStatsOp<CUDAContext>::ComputeChannelStatsNHWC<float>(
    const int N,
    const int C,
    const int HxW,
    const float* X,
    float* sum,
    float* sumsq) {
  ChannelStatsNHWCCUDAKernel<float>
      <<<C, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
          N,
          C,
          numBlocksPerChannel,
          sumScratch_.data<float>(),
          sumsqScratch_.data<float>(),
          sum->template mutable_data<float>(),
          sumsq->template mutable_data<float>());

          N, C, HxW, X, sum, sumsq);
  return true;
}
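The two new kernels differ only in how they linearize the tensor: the NCHW kernel reads index = (n * C + c) * HxW + hw with one block per channel, while the NHWC kernel reads index = i * C + c where i walks the N*HxW spatial positions. A quick NumPy check of that index arithmetic, illustrative only and not part of the commit:

import numpy as np

N, C, H, W = 2, 3, 4, 5
HxW = H * W
X_nchw = np.arange(N * C * HxW, dtype=np.float32).reshape(N, C, H, W)
X_nhwc = np.transpose(X_nchw, (0, 2, 3, 1))

n, c, hw = 1, 2, 7
h, w = divmod(hw, W)
# NCHW kernel: one block per channel c, threads stride over n and hw.
assert X_nchw.ravel()[(n * C + c) * HxW + hw] == X_nchw[n, c, h, w]
# NHWC kernel: i enumerates the N*HxW positions, channel stride is C.
i = n * HxW + hw
assert X_nhwc.ravel()[i * C + c] == X_nchw[n, c, h, w]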
caffe2/operators/channel_stats_op.h
@@ -1,5 +1,7 @@
#ifndef CAFFE2_OPERATORS_CHANNEL_STATS_OP_H
#define CAFFE2_OPERATORS_CHANNEL_STATS_OP_H
#ifndef CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
#define CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_

#include <string>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
@@ -8,26 +10,51 @@
namespace caffe2 {

template <class Context>
class ChannelStatsOp : public Operator<Context> {
class ChannelStatsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit ChannelStatsOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  ~ChannelStatsOp() {}

  bool RunOnDevice() override {
    return true;
      : Operator<Context>(std::forward<Args>(args)...),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<std::string>("order", "NCHW"))) {
    CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
  }

 protected:
  INPUT_TAGS(INPUT);
  OUTPUT_TAGS(SUM, SUMSQ);
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
  }

  Tensor sumScratch_;
  Tensor sumsqScratch_;
  template <typename T>
  bool DoRunWithType() {
    const auto& X = Input(0);
    const int ndim = X.dim();
    const int N = X.dim32(0);
    const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
    const int HxW = X.numel() / (N * C);
    auto* sum = Output(0, {C}, at::dtype<T>());
    auto* sumsq = Output(1, {C}, at::dtype<T>());
    const T* X_data = X.template data<T>();
    T* sum_data = sum->template mutable_data<T>();
    T* sumsq_data = sumsq->template mutable_data<T>();
    return order_ == StorageOrder::NCHW
        ? ComputeChannelStatsNCHW<T>(N, C, HxW, X_data, sum_data, sumsq_data)
        : ComputeChannelStatsNHWC<T>(N, C, HxW, X_data, sum_data, sumsq_data);
  }

 private:
  template <typename T>
  bool
  ComputeChannelStatsNCHW(int N, int C, int HxW, const T* X, T* sum, T* sumsq);

  template <typename T>
  bool
  ComputeChannelStatsNHWC(int N, int C, int HxW, const T* X, T* sum, T* sumsq);

  const StorageOrder order_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_CHANNEL_STATS_OP_H
#endif // CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
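DoRunWithType above derives C and HxW from the input shape according to the order argument before dispatching to the NCHW or NHWC implementation. A minimal NumPy sketch of that shape bookkeeping (the helper name is hypothetical, not part of the commit):

import numpy as np

def channel_dims(shape, order):
    # Mirrors DoRunWithType: N is dim 0, C depends on layout,
    # HxW is everything else folded together.
    N = shape[0]
    C = shape[1] if order == "NCHW" else shape[-1]
    HxW = int(np.prod(shape)) // (N * C)
    return N, C, HxW

assert channel_dims((2, 3, 4, 5), "NCHW") == (2, 3, 20)
assert channel_dims((2, 4, 5, 3), "NHWC") == (2, 3, 20)
assert channel_dims((2, 3, 4, 5, 6), "NCHW") == (2, 3, 120)  # 5D (NCDHW) case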
caffe2/python/operator_test/channel_stats_op_test.py
@@ -1,42 +1,85 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial

from hypothesis import assume, given
import hypothesis.strategies as st
import numpy as np

import unittest


class TestChannelStats(serial.SerializedTestCase):
    @serial.given(
        size=st.integers(7, 10),
        inputChannels=st.integers(1, 10),
        batchSize=st.integers(1, 3),
        **hu.gcs
    )
    def testChannelStats(self, size, inputChannels, batchSize, gc, dc):

class TestChannelStatsOp(serial.SerializedTestCase):
    def channel_stats_nchw_ref(self, X):
        dims = X.shape
        N = dims[0]
        C = dims[1]
        X = X.reshape(N, C, -1)
        sum1 = np.sum(X, axis=(0, 2), keepdims=False)
        sum2 = np.sum(X**2, axis=(0, 2), keepdims=False)
        return (sum1, sum2)

    def channel_stats_nhwc_ref(self, X):
        dims = X.shape
        N = dims[0]
        C = dims[-1]
        X = X.reshape(N, -1, C)
        sum1 = np.sum(X, axis=(0, 1), keepdims=False)
        sum2 = np.sum(X**2, axis=(0, 1), keepdims=False)
        return (sum1, sum2)

    @serial.given(
        N=st.integers(1, 5), C=st.integers(1, 10), H=st.integers(1, 12),
        W=st.integers(1, 12), order=st.sampled_from(["NCHW", "NHWC"]), **hu.gcs)
    def test_channel_stats_2d(self, N, C, H, W, order, gc, dc):
        op = core.CreateOperator(
            "ChannelStats",
            ["X"],
            ["sum", "sumsq"],
            order=order,
        )

        def referenceChannelStatsTest(X):
            sums = np.sum(X, axis=(0, 2, 3), keepdims=False)
            sumsq = np.zeros(inputChannels)
            sumsq = np.sum(X**2, axis=(0, 2, 3), keepdims=False)
            return sums, sumsq
        def ref_op(X):
            if order == "NCHW":
                return self.channel_stats_nchw_ref(X)
            else:
                return self.channel_stats_nhwc_ref(X)

        X = np.random.rand(batchSize, inputChannels, size, size)\
            .astype(np.float32) - 0.5
        self.assertReferenceChecks(gc, op, [X], referenceChannelStatsTest)
        X = np.random.randn(N, C, H, W).astype(np.float32)
        if order == "NHWC":
            X = np.transpose(X, [0, 2, 3, 1])

        self.assertReferenceChecks(gc, op, [X], reference=ref_op)
        self.assertDeviceChecks(dc, op, [X], [0, 1])

    @serial.given(
        N=st.integers(1, 5), C=st.integers(1, 10), D=st.integers(1, 6),
        H=st.integers(1, 6), W=st.integers(1, 6),
        order=st.sampled_from(["NCHW", "NHWC"]), **hu.gcs)
    def test_channel_stats_3d(self, N, C, D, H, W, order, gc, dc):
        op = core.CreateOperator(
            "ChannelStats",
            ["X"],
            ["sum", "sumsq"],
            order=order,
        )

        def ref_op(X):
            if order == "NCHW":
                return self.channel_stats_nchw_ref(X)
            else:
                return self.channel_stats_nhwc_ref(X)

        X = np.random.randn(N, C, D, H, W).astype(np.float32)
        if order == "NHWC":
            X = np.transpose(X, [0, 2, 3, 4, 1])

        self.assertReferenceChecks(gc, op, [X], reference=ref_op)
        self.assertDeviceChecks(dc, op, [X], [0, 1])


if __name__ == "__main__":
    unittest.main()
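Beyond the hypothesis-based checks above, the operator can also be exercised directly through the Caffe2 workspace. A minimal usage sketch, assuming a working caffe2.python install; the blob names here are illustrative:

import numpy as np
from caffe2.python import core, workspace

X = np.random.randn(2, 3, 8, 8).astype(np.float32)
workspace.FeedBlob("X", X)
op = core.CreateOperator("ChannelStats", ["X"], ["sum", "sumsq"], order="NCHW")
workspace.RunOperatorOnce(op)
sum_out = workspace.FetchBlob("sum")
sumsq_out = workspace.FetchBlob("sumsq")
assert np.allclose(sum_out, X.sum(axis=(0, 2, 3)), atol=1e-4)
assert np.allclose(sumsq_out, (X ** 2).sum(axis=(0, 2, 3)), atol=1e-4)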
caffe2/python/operator_test/group_norm_op_test.py
@@ -10,6 +10,8 @@ from hypothesis import given
import hypothesis.strategies as st
import numpy as np

import unittest


class TestGroupNormOp(serial.SerializedTestCase):
    def group_norm_nchw_ref(self, X, gamma, beta, group, epsilon):
@@ -144,3 +146,7 @@ class TestGroupNormOp(serial.SerializedTestCase):
        inputs = [X, gamma, beta]
        for i in range(len(inputs)):
            self.assertGradientChecks(gc, op, inputs, i, [0])


if __name__ == "__main__":
    unittest.main()
caffe2/python/serialized_test/SerializedTestCoverage.md
@@ -1,11 +1,11 @@
# Serialized Test Coverage Report
This is an automatically generated file. Please see `caffe2/python/serialized_test/README.md` for details. In the case of merge conflicts, please rebase and regenerate.
## Summary
Serialized tests have covered 217/684 (31.7%) operators
Serialized tests have covered 219/688 (31.8%) operators

## Not covered operators
<details>
<summary>There are 467 not covered operators</summary>
<summary>There are 469 not covered operators</summary>

* APMeter
* ATen
@@ -17,6 +17,7 @@ Serialized tests have covered 217/684 (31.7%) operators
* Adam
* Add
* AddGradient
* AdjustBatch
* Alias
* Allgather
* Allreduce
@@ -96,6 +97,7 @@ Serialized tests have covered 217/684 (31.7%) operators
* CubeGradient
* DBExists
* DataCouple
* DenseVectorToIdList
* DepthConcat
* DepthSplit
* DequeueBlobs
@@ -478,7 +480,7 @@ Serialized tests have covered 217/684 (31.7%) operators

## Covered operators
<details>
<summary>There are 217 covered operators</summary>
<summary>There are 219 covered operators</summary>

* Acos
* AcosGradient
@@ -543,6 +545,8 @@ Serialized tests have covered 217/684 (31.7%) operators
* ElementwiseLinearGradient
* Elu
* EluGradient
* Erf
* ErfGradient
* Expand
* ExpandGradient
* FC
@@ -702,7 +706,7 @@ Serialized tests have covered 217/684 (31.7%) operators
## Excluded from coverage statistics
### Schemaless operators
<details>
<summary>There are 21 schemaless operators</summary>
<summary>There are 22 schemaless operators</summary>

* C10Add_DontUseThisOpYet
* C10AveragedLoss_DontUseThisOpYet
@@ -718,6 +722,7 @@ Serialized tests have covered 217/684 (31.7%) operators
* C10GivenTensorFill_DontUseThisOpYet
* C10GivenTensorInt64Fill_DontUseThisOpYet
* C10GivenTensorIntFill_DontUseThisOpYet
* C10LayerNorm_DontUseThisOpYet
* C10Mul_DontUseThisOpYet
* C10Relu_DontUseThisOpYet
* C10SigmoidCrossEntropyWithLogits_DontUseThisOpYet
Binary file not shown.
Binary file not shown.