Change ConvPoolOp<Context>::SetOutputSize to ConvPoolOp<Context>::GetOutputSize (#17764)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17764

Original commit changeset: f1923fdca4a1

reverted int8 ops fixes the original runtime regression.
We'll ignore the memory regression since it is flaky, see D14228484

Reviewed By: dzhulgakov

Differential Revision: D13885233

fbshipit-source-id: ccbe4b94acb44b7b4cb3ae4d73e3f6091e1e1195
This commit is contained in:
Jerry Zhang 2019-03-07 18:31:33 -08:00 committed by Facebook Github Bot
parent cc7aec12fd
commit ac87488bd3
16 changed files with 202 additions and 49 deletions

View File

@ -196,8 +196,8 @@ class MaxPoolRTCOp final : public ConvPoolOpBase<CUDAContext> {
bool RunOnDeviceWithOrderNCHW() override {
auto& X = Input(0);
auto* Y = Output(0);
ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(1));
auto output_sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, X.dim32(1));
auto* Y = Output(0, output_sizes, at::dtype<float>());
if (input_dims_ != X.sizes()) {
// recompile

View File

@ -257,11 +257,10 @@ void computeOutputHW(
int* OH,
int* OW) {
Tensor input = caffe2::empty({1, 1, H, W}, at::dtype<float>().device(CPU));
Tensor output(CPU);
op->SetOutputSize(input, &output, 1);
CAFFE_ENFORCE_EQ(output.dim(), 4);
*OH = output.size(2);
*OW = output.size(3);
auto sizes = op->GetOutputSize(input, 1);
CAFFE_ENFORCE_EQ(sizes.size(), 4);
*OH = sizes[2];
*OW = sizes[3];
}
constexpr int computeMPSAlignOffset(int kernel, int pad) {

View File

@ -514,13 +514,13 @@ template <typename T_X, typename T_W, typename T_B, typename T_Y>
bool CudnnConvOp::DoRunWithType() {
auto& X = Input(INPUT);
auto& filter = Input(FILTER);
auto* Y = Output(0);
// Figure out the output shape
CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
CAFFE_ENFORCE(filter.dim() >= 3 && filter.dim() <= 5);
const int M = filter.dim32(0);
ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, M);
auto output_sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, M);
auto* Y = Output(0, output_sizes, at::dtype<T_Y>());
int N = 0, C = 0, H = 0, W = 0, D = 0, H_out = 0, W_out = 0, D_out = 0;
int group_offset_X = 0, group_offset_Y = 0;

View File

@ -208,7 +208,7 @@ class ConvPoolOpBase : public Operator<Context> {
return size;
}
// Sets the output size. The output channel is manually provided since
// Gets the output size. The output channel is manually provided since
// it may not be identical to the input channels.
// This function can be used in the forward functions to obtain the output
// sizes.
@ -216,7 +216,25 @@ class ConvPoolOpBase : public Operator<Context> {
// implementations that do not use first-class Tensor objects, such as the
// MKL operator. One can still call this function with dummy
// Tensor objects in order to obtain the sizes.
// TODO: passing sizes directly rather than Tensor
std::vector<int64_t> GetOutputSize(const Tensor& input, int output_channel) {
CAFFE_ENFORCE_GE(input.dim(), 2);
const int inner_size = input.size_from_dim(1);
CAFFE_ENFORCE_GT(inner_size, 0);
std::vector<int64_t> output_dims;
InferOutputSize64(
input.sizes(),
output_channel,
order_,
global_pooling_,
legacy_pad_,
dilation_,
stride_,
&kernel_,
&pads_,
&output_dims);
return output_dims;
}
void SetOutputSize(const Tensor& input, Tensor* output, int output_channel) {
const int inner_size = input.size_from_dim(1);
CAFFE_ENFORCE_GT(inner_size, 0);
@ -276,6 +294,45 @@ class ConvPoolOpBase : public Operator<Context> {
}
}
static void InferOutputSize64(
const at::IntList& input_dims,
const int output_channel,
const StorageOrder order,
const bool global_pooling,
const LegacyPadding legacy_pad,
const std::vector<int>& dilation,
const std::vector<int>& stride,
std::vector<int>* kernel,
std::vector<int>* pads,
std::vector<int64_t>* output_dims) {
CAFFE_ENFORCE_NE(order, StorageOrder::UNKNOWN);
const int ndim = input_dims.size() - 2;
output_dims->resize(ndim + 2);
output_dims->front() = input_dims.front();
if (order == StorageOrder::NCHW) {
output_dims->at(1) = output_channel;
} else {
output_dims->back() = output_channel;
}
const int offset = order == StorageOrder::NCHW ? 2 : 1;
if (global_pooling) {
std::copy_n(input_dims.cbegin() + offset, ndim, kernel->begin());
std::fill_n(output_dims->begin() + offset, ndim, 1LL);
} else {
for (int i = 0; i < ndim; ++i) {
ComputeSizeAndPad64(
input_dims[i + offset],
stride[i],
kernel->at(i),
dilation[i],
legacy_pad,
&pads->at(i),
&pads->at(i + ndim),
&output_dims->at(i + offset));
}
}
}
// ComputePads could be used in backward functions to figure out the padding
// values for the given input.
void ComputePads(const vector<int>& dims) {
@ -670,6 +727,85 @@ class ConvPoolOpBase : public Operator<Context> {
}
}
static inline void ComputeSizeAndPad64(
const int in_size,
const int stride,
const int kernel,
const int dilation,
LegacyPadding legacy_pad,
int* pad_head,
int* pad_tail,
int64_t* out_size) {
const int dkernel = dilation * (kernel - 1) + 1;
switch (legacy_pad) {
case LegacyPadding::NOTSET:
// We will just use the direct padding head and tail values, but we
// will verify that they are non-negative.
CAFFE_ENFORCE_GE(in_size + *pad_head + *pad_tail, dkernel);
*out_size = static_cast<int>(
static_cast<float>(in_size + *pad_head + *pad_tail - dkernel) /
stride +
1);
break;
case LegacyPadding::VALID:
*pad_head = 0;
*pad_tail = 0;
*out_size = (in_size - dkernel) / stride + 1;
break;
case LegacyPadding::SAME: {
CAFFE_ENFORCE(
1 == dilation, "Dilation not supported for legacy padding.");
int legacy_target_size = (in_size + stride - 1) / stride;
int pad_needed = (legacy_target_size - 1) * stride + kernel - in_size;
if (CAFFE2_PAD_HEAD_MORE) {
*pad_head = (pad_needed + 1) / 2;
} else {
*pad_head = pad_needed / 2;
}
*pad_tail = pad_needed - *pad_head;
*out_size = (in_size + pad_needed - dkernel) / stride + 1;
break;
}
case LegacyPadding::CAFFE_LEGACY_POOLING:
// This is in order to adapt Caffe's pooling padding case. In this case,
// we will only use pad_head and will compute pad_tail to match the
// old caffe pooling strategy. Also see caffe2_legacy.proto for more
// details.
CAFFE_ENFORCE_GE(*pad_head, 0);
// Here, notice that caffe casts UP while caffe2 casts DOWN for the
// output size computation.
*out_size = std::ceil(
static_cast<float>(in_size + *pad_head * 2 - kernel) / stride + 1);
// If we have padding, caffe also ensures that the last pooling starts
// strictly inside the image (instead of at the padding); otherwise clip
// the last.
if (*pad_head > 0 && (*out_size - 1) * stride >= in_size + *pad_head) {
--*out_size;
}
// Now, compare the output size with the standard Caffe2 output size.
// The
// caffe2 standard output size should always be no larger than the
// output
// size of caffe.
int standard_out_size = static_cast<int>(
static_cast<float>(in_size + *pad_head * 2 - kernel) / stride + 1);
CAFFE_ENFORCE_GE(
*out_size,
standard_out_size,
"This should never happen. If this happens, double check the logic "
"above.");
if (*out_size > standard_out_size) {
LOG(WARNING)
<< "You are hitting a case where Caffe's legacy padding calculation "
"is hit. This leads to inefficient and sometimes incorrect "
"results. We are keeping this behavior for backward compatibility"
", but you are strongly recommended to move away from it.";
}
*pad_tail = *pad_head + stride * (*out_size - standard_out_size);
break;
}
}
// Accessors for 2D conv params.
inline int pad_t() const {

View File

@ -288,7 +288,6 @@ class Depthwise3x3ConvOp final : public ConvPoolOpBase<CUDAContext> {
bool RunOnDeviceWithOrderNCHW() override {
const Tensor& X = Input(0);
auto& filter = Input(1);
Tensor* Y = Output(0);
const int N = X.dim32(0), C = X.dim32(1);
CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
const int M = filter.dim32(0);
@ -300,7 +299,8 @@ class Depthwise3x3ConvOp final : public ConvPoolOpBase<CUDAContext> {
CAFFE_ENFORCE_EQ(this->kernel_w(), 3);
CAFFE_ENFORCE_EQ(this->kernel_h(), 3);
CAFFE_ENFORCE_EQ(this->stride_h(), this->stride_w());
ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, filter.dim32(0));
auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, filter.dim32(0));
Tensor* Y = Output(0, sizes, at::dtype<float>());
DepthwiseArgs args;
args.batch = X.dim32(0);
args.in_rows = X.dim32(2);
@ -458,7 +458,7 @@ class Depthwise3x3ConvGradientOp final : public ConvPoolOpBase<CUDAContext> {
M,
dY.dim32(2),
dY.dim32(3)));
auto* dbias = Output(BIAS_OR_INPUT_GRAD, {M}, at::dtype<float>());
CUDNN_ENFORCE(cudnnConvolutionBackwardBias(
cudnn_wrapper_.inline_cudnn_handle(),

View File

@ -207,7 +207,6 @@ template <typename T_X, typename T_W, typename T_B, typename MATH, typename T_Y>
bool MIOPENConvOp::DoRunWithType() {
auto& X = Input(INPUT);
auto& Weight = Input(FILTER);
auto* Y = Output(0);
// Figure out the output shape
CAFFE_ENFORCE(X.ndim() >= 3 && X.ndim() <= 5);
@ -216,7 +215,8 @@ bool MIOPENConvOp::DoRunWithType() {
"Conv op with MIOpen engine is supported only for 2D convolutions");
const int M = Weight.dim32(0);
ConvPoolOpBase<HIPContext>::SetOutputSize(X, Y, M);
auto sizes = ConvPoolOpBase<HIPContext>::GetOutputSize(X, M);
auto* Y = Output(0, sizes, at::dtype<T_Y>());
int N = X.dim32(0);
int C = X.dim32(1);

View File

@ -61,7 +61,6 @@ class MIOPENPoolOp : public ConvPoolOpBase<HIPContext> {
template <typename T, typename M>
bool DoRunWithType() {
auto& X = Input(0);
auto* Y = Output(0);
int N = 0, C = 0, H = 0, W = 0, D = 0;
int N_out = 0, C_out = 0, H_out = 0, W_out = 0;
CAFFE_ENFORCE(X.ndim() >= 4 && X.ndim() <= 5);
@ -69,7 +68,8 @@ class MIOPENPoolOp : public ConvPoolOpBase<HIPContext> {
C = X.dim32(1);
H = X.dim32(2);
W = X.ndim() > 3 ? X.dim32(3) : 1;
ConvPoolOpBase::SetOutputSize(X, Y, C);
auto sizes = ConvPoolOpBase::GetOutputSize(X, C);
auto* Y = Output(0, sizes, at::dtype<T>());
N_out = Y->dim32(0);
C_out = Y->dim32(1);

View File

@ -108,9 +108,10 @@ __global__ void MaxPoolBackward(
template <typename T>
bool MaxPoolWithIndexOp::DoRunWithType() {
auto& X = Input(0);
auto* Y = Output(0);
ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, X.dim32(1));
auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, X.dim32(1));
auto* Y = Output(0, sizes, at::dtype<T>());
int output_size = Y->numel();
auto* mask = Output(1, {output_size}, at::dtype<int>());

View File

@ -251,12 +251,13 @@ __global__ void PadImageGradientEdgeNHWC(
template <>
bool PadImageOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
auto& X = Input(0);
auto* Y = Output(0);
const int num = X.dim32(0);
const int channels = X.dim32(1);
const int height = X.dim32(2);
const int width = X.dim32(3);
ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, channels);
auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, channels);
auto* Y = Output(0, sizes, at::dtype<float>());
const int output_size = Y->numel();
const int padded_height = Y->dim32(2);
const int padded_width = Y->dim32(3);
@ -327,12 +328,13 @@ bool PadImageOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
template<>
bool PadImageOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
auto& X = Input(0);
auto* Y = Output(0);
const int num = X.dim32(0);
const int height = X.dim32(1);
const int width = X.dim32(2);
const int channels = X.dim32(3);
ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, channels);
auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, channels);
auto* Y = Output(0, sizes, at::dtype<float>());
const int output_size = Y->numel();
const int padded_height = Y->dim32(1);
const int padded_width = Y->dim32(2);
@ -403,7 +405,7 @@ bool PadImageOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
template<>
bool PadImageGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
auto& dY = Input(0);
auto* dX = Output(0, { dY.dim32(0),
dY.dim32(1),
dY.dim32(2) - pad_t() - pad_b(),
@ -483,7 +485,7 @@ bool PadImageGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
template<>
bool PadImageGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
auto& dY = Input(0);
auto* dX = Output(0, { dY.dim32(0),
dY.dim32(1) - pad_t() - pad_b(),
dY.dim32(2) - pad_l() - pad_r(),

View File

@ -100,11 +100,11 @@ class CuDNNPoolOp final : public ConvPoolOpBase<CUDAContext> {
template <typename T>
bool DoRunWithType() {
const auto& X = Input(0);
auto* Y = Output(0);
const int ndim = X.dim();
const int N = X.dim32(0);
const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
ConvPoolOpBase<CUDAContext>::SetOutputSize(X, Y, C);
auto sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, C);
auto* Y = Output(0, sizes, at::dtype<T>());
const T* X_data = X.template data<T>();
T* Y_data = Y->template mutable_data<T>();

View File

@ -102,8 +102,8 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
const Tensor& X = InputTensorCPU_(INPUT);
int N = X.dim32(0);
Tensor* Y = OutputTensorCPU_(0);
this->SetOutputSize(X, Y, filter.dim32(0));
auto sizes = this->GetOutputSize(X, filter.dim32(0));
Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());
const int output_image_size = this->GetDimsSize(*Y);
if (N * output_image_size < FLAGS_caffe2_dnnlowp_acc16_m_threshold) {
@ -228,7 +228,6 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
const Tensor& X = InputTensorCPU_(INPUT);
auto& filter = InputTensorCPU_(FILTER);
Tensor* Y = OutputTensorCPU_(0);
const int N = X.dim32(0), C = X.dim32(1);
CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim());
const int M = filter.dim32(0);
@ -246,7 +245,8 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
0,
"The number of output channels is not divisible by group.");
this->SetOutputSize(X, Y, filter.dim32(0));
auto sizes = this->GetOutputSize(X, filter.dim32(0));
Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());
const vector<int> input_dims = GetDims(X);
const vector<int> output_dims = GetDims(*Y);
@ -618,14 +618,14 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWC() {
const Tensor& X = InputTensorCPU_(INPUT);
auto& filter = InputTensorCPU_(FILTER);
Tensor* Y = OutputTensorCPU_(0);
const int N = X.dim32(0), C = X.dim32(X.ndim() - 1);
CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim());
const int M = filter.dim32(0);
CAFFE_ENFORCE_EQ(filter.dim32(filter.ndim() - 1), C / group_);
this->SetOutputSize(X, Y, filter.dim32(0));
auto sizes = this->GetOutputSize(X, filter.dim32(0));
Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());
// The dimension of each kernel
const int kernel_dim = this->KernelDim_();
// The output image size is the spatial size of the output.

View File

@ -560,7 +560,6 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHW() {
const Tensor& X = InputTensorCPU_(INPUT);
auto& filter = InputTensorCPU_(FILTER);
Tensor* Y = OutputTensorCPU_(0);
const int N = X.dim32(0), C = X.dim32(1);
CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
const int M = filter.dim32(0);
@ -578,7 +577,8 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHW() {
0,
"The number of output channels is not divisible by group.");
ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
const vector<int> input_dims = GetDims(X);
const vector<int> output_dims = GetDims(*Y);
@ -1418,7 +1418,6 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWC() {
const Tensor& X = InputTensorCPU_(INPUT);
auto& filter = InputTensorCPU_(FILTER);
Tensor* Y = OutputTensorCPU_(0);
const int C = X.dim32(X.dim() - 1);
const int G = group_;
CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
@ -1435,7 +1434,8 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWC() {
CAFFE_ENFORCE_EQ(
M % G, 0, "The number of output channels is not divisible by group.");
ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
// The col buffer is stored in HWC order as well - kernel_dim, and the height
// and width.

View File

@ -69,6 +69,12 @@ class ConvPoolDNNLowPOpBase : public ConvPoolOpBase<CPUContext> {
return &Outputs()[idx]->template GetMutable<int8::Int8TensorCPU>()->t;
}
Tensor* OutputTensorCPU_(int idx, at::IntList dims, at::TensorOptions options) {
auto* t = &Outputs()[idx]->template GetMutable<int8::Int8TensorCPU>()->t;
ReinitializeTensor(t, dims, options.device(CPU));
return t;
}
T* GetQuantizedOutputData_() {
return OutputTensorCPU_(0)->template mutable_data<T>();
}

View File

@ -122,6 +122,16 @@ class DNNLowPOp : public Operator<CPUContext> {
}
}
Tensor* OutputTensorCPU_(int idx, at::IntList dims, at::TensorOptions options) {
if (dequantize_output_) {
return Output(idx, dims, options.device(CPU));
} else {
auto* t = &Outputs()[idx]->template GetMutable<int8::Int8TensorCPU>()->t;
ReinitializeTensor(t, dims, options.device(CPU));
return t;
}
}
T* GetQuantizedOutputData_() {
if (dequantize_output_) {
out_temp_.resize(Output(0)->numel());

View File

@ -101,8 +101,8 @@ class AveragePoolDnnLowPOp final
GetOutputQuantizationParams_();
auto& X = InputTensorCPU_(0);
auto* Y = OutputTensorCPU_(0);
ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, X.dim32(1));
auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, X.dim32(1));
auto* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
T* Ydata = GetQuantizedOutputData_();
@ -239,9 +239,9 @@ class AveragePoolDnnLowPOp final
GetOutputQuantizationParams_();
auto& X = InputTensorCPU_(0);
auto* Y = OutputTensorCPU_(0);
int channels = X.dim32(X.ndim() - 1);
ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, channels);
auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, channels);
auto* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
T* Ydata = GetQuantizedOutputData_();
@ -398,8 +398,8 @@ class MaxPoolDnnLowPOp final : public ConvPoolDNNLowPOpBase<T, MaxPoolFp32Op> {
const T* Xdata = QuantizeInputIfNeeded(this, 0, in_qparams_[0], X_temp);
auto& X = InputTensorCPU_(0);
auto* Y = OutputTensorCPU_(0);
ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, X.dim32(1));
auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, X.dim32(1));
auto* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
T* Ydata = GetQuantizedOutputData_();
@ -544,9 +544,9 @@ class MaxPoolDnnLowPOp final : public ConvPoolDNNLowPOpBase<T, MaxPoolFp32Op> {
const T* Xdata = QuantizeInputIfNeeded(this, 0, in_qparams_[0], X_temp);
auto& X = InputTensorCPU_(0);
auto* Y = OutputTensorCPU_(0);
int channels = X.dim32(X.ndim() - 1);
ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, channels);
auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, channels);
auto* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
T* Ydata = GetQuantizedOutputData_();

View File

@ -442,7 +442,6 @@ class Depthwise3x3ConvOp final : public ConvPoolOpBase<CPUContext> {
bool RunOnDeviceWithOrderNCHW() override {
const Tensor& X = Input(0);
auto& filter = Input(1);
Tensor* Y = Output(0);
const int N = X.dim32(0), C = X.dim32(1);
CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim());
const int M = filter.dim32(0);
@ -452,8 +451,8 @@ class Depthwise3x3ConvOp final : public ConvPoolOpBase<CPUContext> {
CAFFE_ENFORCE_EQ(C, this->group_);
CAFFE_ENFORCE_EQ(M, this->group_);
ConvPoolOpBase<CPUContext>::SetOutputSize(X, Y, filter.dim32(0));
Y->mutable_data<float>();
auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
Tensor* Y = Output(0, sizes, at::dtype<float>());
DepthwiseArgs args;
args.batch = X.dim32(0);