diff --git a/caffe2/cuda_rtc/pool_op_rtc_gpu.cc b/caffe2/cuda_rtc/pool_op_rtc_gpu.cc index f7e4e9d9a19..5b455811f00 100644 --- a/caffe2/cuda_rtc/pool_op_rtc_gpu.cc +++ b/caffe2/cuda_rtc/pool_op_rtc_gpu.cc @@ -196,8 +196,8 @@ class MaxPoolRTCOp final : public ConvPoolOpBase { bool RunOnDeviceWithOrderNCHW() override { auto& X = Input(0); - auto* Y = Output(0); - ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(1)); + auto output_sizes = ConvPoolOpBase::GetOutputSize(X, X.dim32(1)); + auto* Y = Output(0, output_sizes, at::dtype()); if (input_dims_ != X.sizes()) { // recompile diff --git a/caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm b/caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm index 7494fe13fb7..f556e9c7956 100644 --- a/caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm +++ b/caffe2/mobile/contrib/ios/mpscnn/mpscnn.mm @@ -257,11 +257,10 @@ void computeOutputHW( int* OH, int* OW) { Tensor input = caffe2::empty({1, 1, H, W}, at::dtype().device(CPU)); - Tensor output(CPU); - op->SetOutputSize(input, &output, 1); - CAFFE_ENFORCE_EQ(output.dim(), 4); - *OH = output.size(2); - *OW = output.size(3); + auto sizes = op->GetOutputSize(input, 1); + CAFFE_ENFORCE_EQ(sizes.size(), 4); + *OH = sizes[2]; + *OW = sizes[3]; } constexpr int computeMPSAlignOffset(int kernel, int pad) { diff --git a/caffe2/operators/conv_op_cudnn.cc b/caffe2/operators/conv_op_cudnn.cc index c48c943afed..8c510f7387b 100644 --- a/caffe2/operators/conv_op_cudnn.cc +++ b/caffe2/operators/conv_op_cudnn.cc @@ -514,13 +514,13 @@ template bool CudnnConvOp::DoRunWithType() { auto& X = Input(INPUT); auto& filter = Input(FILTER); - auto* Y = Output(0); // Figure out the output shape CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5); CAFFE_ENFORCE(filter.dim() >= 3 && filter.dim() <= 5); const int M = filter.dim32(0); - ConvPoolOpBase::SetOutputSize(X, Y, M); + auto output_sizes = ConvPoolOpBase::GetOutputSize(X, M); + auto* Y = Output(0, output_sizes, at::dtype()); int N = 0, C = 0, H = 0, W = 0, D = 0, H_out = 0, W_out = 0, D_out = 0; int group_offset_X = 0, group_offset_Y = 0; diff --git a/caffe2/operators/conv_pool_op_base.h b/caffe2/operators/conv_pool_op_base.h index e403bfae95f..58ca0944924 100644 --- a/caffe2/operators/conv_pool_op_base.h +++ b/caffe2/operators/conv_pool_op_base.h @@ -208,7 +208,7 @@ class ConvPoolOpBase : public Operator { return size; } - // Sets the output size. The output channel is manually provided since + // Gets the output size. The output channel is manually provided since // it may not be identical to the input channels. // This function can be used in the forward functions to obtain the output // sizes. @@ -216,7 +216,25 @@ class ConvPoolOpBase : public Operator { // implementations that do not use first-class Tensor objects, such as the // MKL operator. One can still call this function with dummy // Tensor objects in order to obtain the sizes. - // TODO: passing sizes directly rather than Tensor + std::vector GetOutputSize(const Tensor& input, int output_channel) { + CAFFE_ENFORCE_GE(input.dim(), 2); + const int inner_size = input.size_from_dim(1); + CAFFE_ENFORCE_GT(inner_size, 0); + std::vector output_dims; + InferOutputSize64( + input.sizes(), + output_channel, + order_, + global_pooling_, + legacy_pad_, + dilation_, + stride_, + &kernel_, + &pads_, + &output_dims); + return output_dims; + } + void SetOutputSize(const Tensor& input, Tensor* output, int output_channel) { const int inner_size = input.size_from_dim(1); CAFFE_ENFORCE_GT(inner_size, 0); @@ -276,6 +294,45 @@ class ConvPoolOpBase : public Operator { } } + static void InferOutputSize64( + const at::IntList& input_dims, + const int output_channel, + const StorageOrder order, + const bool global_pooling, + const LegacyPadding legacy_pad, + const std::vector& dilation, + const std::vector& stride, + std::vector* kernel, + std::vector* pads, + std::vector* output_dims) { + CAFFE_ENFORCE_NE(order, StorageOrder::UNKNOWN); + const int ndim = input_dims.size() - 2; + output_dims->resize(ndim + 2); + output_dims->front() = input_dims.front(); + if (order == StorageOrder::NCHW) { + output_dims->at(1) = output_channel; + } else { + output_dims->back() = output_channel; + } + const int offset = order == StorageOrder::NCHW ? 2 : 1; + if (global_pooling) { + std::copy_n(input_dims.cbegin() + offset, ndim, kernel->begin()); + std::fill_n(output_dims->begin() + offset, ndim, 1LL); + } else { + for (int i = 0; i < ndim; ++i) { + ComputeSizeAndPad64( + input_dims[i + offset], + stride[i], + kernel->at(i), + dilation[i], + legacy_pad, + &pads->at(i), + &pads->at(i + ndim), + &output_dims->at(i + offset)); + } + } + } + // ComputePads could be used in backward functions to figure out the padding // values for the given input. void ComputePads(const vector& dims) { @@ -670,6 +727,85 @@ class ConvPoolOpBase : public Operator { } } + static inline void ComputeSizeAndPad64( + const int in_size, + const int stride, + const int kernel, + const int dilation, + LegacyPadding legacy_pad, + int* pad_head, + int* pad_tail, + int64_t* out_size) { + const int dkernel = dilation * (kernel - 1) + 1; + switch (legacy_pad) { + case LegacyPadding::NOTSET: + // We will just use the direct padding head and tail values, but we + // will verify that they are non-negative. + CAFFE_ENFORCE_GE(in_size + *pad_head + *pad_tail, dkernel); + *out_size = static_cast( + static_cast(in_size + *pad_head + *pad_tail - dkernel) / + stride + + 1); + break; + case LegacyPadding::VALID: + *pad_head = 0; + *pad_tail = 0; + *out_size = (in_size - dkernel) / stride + 1; + break; + case LegacyPadding::SAME: { + CAFFE_ENFORCE( + 1 == dilation, "Dilation not supported for legacy padding."); + int legacy_target_size = (in_size + stride - 1) / stride; + int pad_needed = (legacy_target_size - 1) * stride + kernel - in_size; + if (CAFFE2_PAD_HEAD_MORE) { + *pad_head = (pad_needed + 1) / 2; + } else { + *pad_head = pad_needed / 2; + } + *pad_tail = pad_needed - *pad_head; + *out_size = (in_size + pad_needed - dkernel) / stride + 1; + break; + } + case LegacyPadding::CAFFE_LEGACY_POOLING: + // This is in order to adapt Caffe's pooling padding case. In this case, + // we will only use pad_head and will compute pad_tail to match the + // old caffe pooling strategy. Also see caffe2_legacy.proto for more + // details. + CAFFE_ENFORCE_GE(*pad_head, 0); + // Here, notice that caffe casts UP while caffe2 casts DOWN for the + // output size computation. + *out_size = std::ceil( + static_cast(in_size + *pad_head * 2 - kernel) / stride + 1); + // If we have padding, caffe also ensures that the last pooling starts + // strictly inside the image (instead of at the padding); otherwise clip + // the last. + if (*pad_head > 0 && (*out_size - 1) * stride >= in_size + *pad_head) { + --*out_size; + } + // Now, compare the output size with the standard Caffe2 output size. + // The + // caffe2 standard output size should always be no larger than the + // output + // size of caffe. + int standard_out_size = static_cast( + static_cast(in_size + *pad_head * 2 - kernel) / stride + 1); + CAFFE_ENFORCE_GE( + *out_size, + standard_out_size, + "This should never happen. If this happens, double check the logic " + "above."); + if (*out_size > standard_out_size) { + LOG(WARNING) + << "You are hitting a case where Caffe's legacy padding calculation " + "is hit. This leads to inefficient and sometimes incorrect " + "results. We are keeping this behavior for backward compatibility" + ", but you are strongly recommended to move away from it."; + } + *pad_tail = *pad_head + stride * (*out_size - standard_out_size); + break; + } + } + // Accessors for 2D conv params. inline int pad_t() const { diff --git a/caffe2/operators/depthwise_3x3_conv_op_cudnn.cu b/caffe2/operators/depthwise_3x3_conv_op_cudnn.cu index 36b54b95905..5b761493c41 100644 --- a/caffe2/operators/depthwise_3x3_conv_op_cudnn.cu +++ b/caffe2/operators/depthwise_3x3_conv_op_cudnn.cu @@ -288,7 +288,6 @@ class Depthwise3x3ConvOp final : public ConvPoolOpBase { bool RunOnDeviceWithOrderNCHW() override { const Tensor& X = Input(0); auto& filter = Input(1); - Tensor* Y = Output(0); const int N = X.dim32(0), C = X.dim32(1); CAFFE_ENFORCE_EQ(X.dim(), filter.dim()); const int M = filter.dim32(0); @@ -300,7 +299,8 @@ class Depthwise3x3ConvOp final : public ConvPoolOpBase { CAFFE_ENFORCE_EQ(this->kernel_w(), 3); CAFFE_ENFORCE_EQ(this->kernel_h(), 3); CAFFE_ENFORCE_EQ(this->stride_h(), this->stride_w()); - ConvPoolOpBase::SetOutputSize(X, Y, filter.dim32(0)); + auto sizes = ConvPoolOpBase::GetOutputSize(X, filter.dim32(0)); + Tensor* Y = Output(0, sizes, at::dtype()); DepthwiseArgs args; args.batch = X.dim32(0); args.in_rows = X.dim32(2); @@ -458,7 +458,7 @@ class Depthwise3x3ConvGradientOp final : public ConvPoolOpBase { M, dY.dim32(2), dY.dim32(3))); - + auto* dbias = Output(BIAS_OR_INPUT_GRAD, {M}, at::dtype()); CUDNN_ENFORCE(cudnnConvolutionBackwardBias( cudnn_wrapper_.inline_cudnn_handle(), diff --git a/caffe2/operators/hip/conv_op_miopen.hip b/caffe2/operators/hip/conv_op_miopen.hip index d0836113a20..1bdca48bb49 100644 --- a/caffe2/operators/hip/conv_op_miopen.hip +++ b/caffe2/operators/hip/conv_op_miopen.hip @@ -207,7 +207,6 @@ template bool MIOPENConvOp::DoRunWithType() { auto& X = Input(INPUT); auto& Weight = Input(FILTER); - auto* Y = Output(0); // Figure out the output shape CAFFE_ENFORCE(X.ndim() >= 3 && X.ndim() <= 5); @@ -216,7 +215,8 @@ bool MIOPENConvOp::DoRunWithType() { "Conv op with MIOpen engine is supported only for 2D convolutions"); const int M = Weight.dim32(0); - ConvPoolOpBase::SetOutputSize(X, Y, M); + auto sizes = ConvPoolOpBase::GetOutputSize(X, M); + auto* Y = Output(0, sizes, at::dtype()); int N = X.dim32(0); int C = X.dim32(1); diff --git a/caffe2/operators/hip/pool_op_miopen.hip b/caffe2/operators/hip/pool_op_miopen.hip index 614b6cf09bc..c1d5ee387b3 100644 --- a/caffe2/operators/hip/pool_op_miopen.hip +++ b/caffe2/operators/hip/pool_op_miopen.hip @@ -61,7 +61,6 @@ class MIOPENPoolOp : public ConvPoolOpBase { template bool DoRunWithType() { auto& X = Input(0); - auto* Y = Output(0); int N = 0, C = 0, H = 0, W = 0, D = 0; int N_out = 0, C_out = 0, H_out = 0, W_out = 0; CAFFE_ENFORCE(X.ndim() >= 4 && X.ndim() <= 5); @@ -69,7 +68,8 @@ class MIOPENPoolOp : public ConvPoolOpBase { C = X.dim32(1); H = X.dim32(2); W = X.ndim() > 3 ? X.dim32(3) : 1; - ConvPoolOpBase::SetOutputSize(X, Y, C); + auto sizes = ConvPoolOpBase::GetOutputSize(X, C); + auto* Y = Output(0, sizes, at::dtype()); N_out = Y->dim32(0); C_out = Y->dim32(1); diff --git a/caffe2/operators/max_pool_with_index.cu b/caffe2/operators/max_pool_with_index.cu index 31513b53c52..cefa831e25e 100644 --- a/caffe2/operators/max_pool_with_index.cu +++ b/caffe2/operators/max_pool_with_index.cu @@ -108,9 +108,10 @@ __global__ void MaxPoolBackward( template bool MaxPoolWithIndexOp::DoRunWithType() { auto& X = Input(0); - auto* Y = Output(0); - ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(1)); + auto sizes = ConvPoolOpBase::GetOutputSize(X, X.dim32(1)); + auto* Y = Output(0, sizes, at::dtype()); + int output_size = Y->numel(); auto* mask = Output(1, {output_size}, at::dtype()); diff --git a/caffe2/operators/pad_op_gpu.cu b/caffe2/operators/pad_op_gpu.cu index c623633340b..f9c37af11da 100644 --- a/caffe2/operators/pad_op_gpu.cu +++ b/caffe2/operators/pad_op_gpu.cu @@ -251,12 +251,13 @@ __global__ void PadImageGradientEdgeNHWC( template <> bool PadImageOp::RunOnDeviceWithOrderNCHW() { auto& X = Input(0); - auto* Y = Output(0); const int num = X.dim32(0); const int channels = X.dim32(1); const int height = X.dim32(2); const int width = X.dim32(3); - ConvPoolOpBase::SetOutputSize(X, Y, channels); + auto sizes = ConvPoolOpBase::GetOutputSize(X, channels); + auto* Y = Output(0, sizes, at::dtype()); + const int output_size = Y->numel(); const int padded_height = Y->dim32(2); const int padded_width = Y->dim32(3); @@ -327,12 +328,13 @@ bool PadImageOp::RunOnDeviceWithOrderNCHW() { template<> bool PadImageOp::RunOnDeviceWithOrderNHWC() { auto& X = Input(0); - auto* Y = Output(0); const int num = X.dim32(0); const int height = X.dim32(1); const int width = X.dim32(2); const int channels = X.dim32(3); - ConvPoolOpBase::SetOutputSize(X, Y, channels); + auto sizes = ConvPoolOpBase::GetOutputSize(X, channels); + auto* Y = Output(0, sizes, at::dtype()); + const int output_size = Y->numel(); const int padded_height = Y->dim32(1); const int padded_width = Y->dim32(2); @@ -403,7 +405,7 @@ bool PadImageOp::RunOnDeviceWithOrderNHWC() { template<> bool PadImageGradientOp::RunOnDeviceWithOrderNCHW() { auto& dY = Input(0); - + auto* dX = Output(0, { dY.dim32(0), dY.dim32(1), dY.dim32(2) - pad_t() - pad_b(), @@ -483,7 +485,7 @@ bool PadImageGradientOp::RunOnDeviceWithOrderNCHW() { template<> bool PadImageGradientOp::RunOnDeviceWithOrderNHWC() { auto& dY = Input(0); - + auto* dX = Output(0, { dY.dim32(0), dY.dim32(1) - pad_t() - pad_b(), dY.dim32(2) - pad_l() - pad_r(), diff --git a/caffe2/operators/pool_op_cudnn.cc b/caffe2/operators/pool_op_cudnn.cc index 0e1160a023d..e65680148c7 100644 --- a/caffe2/operators/pool_op_cudnn.cc +++ b/caffe2/operators/pool_op_cudnn.cc @@ -100,11 +100,11 @@ class CuDNNPoolOp final : public ConvPoolOpBase { template bool DoRunWithType() { const auto& X = Input(0); - auto* Y = Output(0); const int ndim = X.dim(); const int N = X.dim32(0); const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1); - ConvPoolOpBase::SetOutputSize(X, Y, C); + auto sizes = ConvPoolOpBase::GetOutputSize(X, C); + auto* Y = Output(0, sizes, at::dtype()); const T* X_data = X.template data(); T* Y_data = Y->template mutable_data(); diff --git a/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc b/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc index b3713377bee..b339e522403 100644 --- a/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc +++ b/caffe2/quantization/server/conv_dnnlowp_acc16_op.cc @@ -102,8 +102,8 @@ bool ConvDNNLowPAcc16Op::GetQuantizationParameters_() { const Tensor& X = InputTensorCPU_(INPUT); int N = X.dim32(0); - Tensor* Y = OutputTensorCPU_(0); - this->SetOutputSize(X, Y, filter.dim32(0)); + auto sizes = this->GetOutputSize(X, filter.dim32(0)); + Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype()); const int output_image_size = this->GetDimsSize(*Y); if (N * output_image_size < FLAGS_caffe2_dnnlowp_acc16_m_threshold) { @@ -228,7 +228,6 @@ bool ConvDNNLowPAcc16Op::RunOnDeviceWithOrderNCHW() { const Tensor& X = InputTensorCPU_(INPUT); auto& filter = InputTensorCPU_(FILTER); - Tensor* Y = OutputTensorCPU_(0); const int N = X.dim32(0), C = X.dim32(1); CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim()); const int M = filter.dim32(0); @@ -246,7 +245,8 @@ bool ConvDNNLowPAcc16Op::RunOnDeviceWithOrderNCHW() { 0, "The number of output channels is not divisible by group."); - this->SetOutputSize(X, Y, filter.dim32(0)); + auto sizes = this->GetOutputSize(X, filter.dim32(0)); + Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype()); const vector input_dims = GetDims(X); const vector output_dims = GetDims(*Y); @@ -618,14 +618,14 @@ bool ConvDNNLowPAcc16Op::RunOnDeviceWithOrderNHWC() { const Tensor& X = InputTensorCPU_(INPUT); auto& filter = InputTensorCPU_(FILTER); - Tensor* Y = OutputTensorCPU_(0); const int N = X.dim32(0), C = X.dim32(X.ndim() - 1); CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim()); const int M = filter.dim32(0); CAFFE_ENFORCE_EQ(filter.dim32(filter.ndim() - 1), C / group_); - this->SetOutputSize(X, Y, filter.dim32(0)); + auto sizes = this->GetOutputSize(X, filter.dim32(0)); + Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype()); // The dimension of each kernel const int kernel_dim = this->KernelDim_(); // The output image size is the spatial size of the output. diff --git a/caffe2/quantization/server/conv_dnnlowp_op.cc b/caffe2/quantization/server/conv_dnnlowp_op.cc index e75789f4855..bb2aeeef21e 100644 --- a/caffe2/quantization/server/conv_dnnlowp_op.cc +++ b/caffe2/quantization/server/conv_dnnlowp_op.cc @@ -560,7 +560,6 @@ bool ConvDNNLowPOp::RunOnDeviceWithOrderNCHW() { const Tensor& X = InputTensorCPU_(INPUT); auto& filter = InputTensorCPU_(FILTER); - Tensor* Y = OutputTensorCPU_(0); const int N = X.dim32(0), C = X.dim32(1); CAFFE_ENFORCE_EQ(X.dim(), filter.dim()); const int M = filter.dim32(0); @@ -578,7 +577,8 @@ bool ConvDNNLowPOp::RunOnDeviceWithOrderNCHW() { 0, "The number of output channels is not divisible by group."); - ConvPoolOpBase::SetOutputSize(X, Y, filter.dim32(0)); + auto sizes = ConvPoolOpBase::GetOutputSize(X, filter.dim32(0)); + Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype()); const vector input_dims = GetDims(X); const vector output_dims = GetDims(*Y); @@ -1418,7 +1418,6 @@ bool ConvDNNLowPOp::RunOnDeviceWithOrderNHWC() { const Tensor& X = InputTensorCPU_(INPUT); auto& filter = InputTensorCPU_(FILTER); - Tensor* Y = OutputTensorCPU_(0); const int C = X.dim32(X.dim() - 1); const int G = group_; CAFFE_ENFORCE_EQ(X.dim(), filter.dim()); @@ -1435,7 +1434,8 @@ bool ConvDNNLowPOp::RunOnDeviceWithOrderNHWC() { CAFFE_ENFORCE_EQ( M % G, 0, "The number of output channels is not divisible by group."); - ConvPoolOpBase::SetOutputSize(X, Y, filter.dim32(0)); + auto sizes = ConvPoolOpBase::GetOutputSize(X, filter.dim32(0)); + Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype()); // The col buffer is stored in HWC order as well - kernel_dim, and the height // and width. diff --git a/caffe2/quantization/server/conv_pool_dnnlowp_op_base.h b/caffe2/quantization/server/conv_pool_dnnlowp_op_base.h index d6359553063..4a2c72bf6ac 100644 --- a/caffe2/quantization/server/conv_pool_dnnlowp_op_base.h +++ b/caffe2/quantization/server/conv_pool_dnnlowp_op_base.h @@ -69,6 +69,12 @@ class ConvPoolDNNLowPOpBase : public ConvPoolOpBase { return &Outputs()[idx]->template GetMutable()->t; } + Tensor* OutputTensorCPU_(int idx, at::IntList dims, at::TensorOptions options) { + auto* t = &Outputs()[idx]->template GetMutable()->t; + ReinitializeTensor(t, dims, options.device(CPU)); + return t; + } + T* GetQuantizedOutputData_() { return OutputTensorCPU_(0)->template mutable_data(); } diff --git a/caffe2/quantization/server/dnnlowp_op.h b/caffe2/quantization/server/dnnlowp_op.h index 8327b278f3a..88a5a1dc153 100644 --- a/caffe2/quantization/server/dnnlowp_op.h +++ b/caffe2/quantization/server/dnnlowp_op.h @@ -122,6 +122,16 @@ class DNNLowPOp : public Operator { } } + Tensor* OutputTensorCPU_(int idx, at::IntList dims, at::TensorOptions options) { + if (dequantize_output_) { + return Output(idx, dims, options.device(CPU)); + } else { + auto* t = &Outputs()[idx]->template GetMutable()->t; + ReinitializeTensor(t, dims, options.device(CPU)); + return t; + } + } + T* GetQuantizedOutputData_() { if (dequantize_output_) { out_temp_.resize(Output(0)->numel()); diff --git a/caffe2/quantization/server/pool_dnnlowp_op.cc b/caffe2/quantization/server/pool_dnnlowp_op.cc index 0dda848bed4..7d6ded95d71 100644 --- a/caffe2/quantization/server/pool_dnnlowp_op.cc +++ b/caffe2/quantization/server/pool_dnnlowp_op.cc @@ -101,8 +101,8 @@ class AveragePoolDnnLowPOp final GetOutputQuantizationParams_(); auto& X = InputTensorCPU_(0); - auto* Y = OutputTensorCPU_(0); - ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(1)); + auto sizes = ConvPoolOpBase::GetOutputSize(X, X.dim32(1)); + auto* Y = OutputTensorCPU_(0, sizes, at::dtype()); T* Ydata = GetQuantizedOutputData_(); @@ -239,9 +239,9 @@ class AveragePoolDnnLowPOp final GetOutputQuantizationParams_(); auto& X = InputTensorCPU_(0); - auto* Y = OutputTensorCPU_(0); int channels = X.dim32(X.ndim() - 1); - ConvPoolOpBase::SetOutputSize(X, Y, channels); + auto sizes = ConvPoolOpBase::GetOutputSize(X, channels); + auto* Y = OutputTensorCPU_(0, sizes, at::dtype()); T* Ydata = GetQuantizedOutputData_(); @@ -398,8 +398,8 @@ class MaxPoolDnnLowPOp final : public ConvPoolDNNLowPOpBase { const T* Xdata = QuantizeInputIfNeeded(this, 0, in_qparams_[0], X_temp); auto& X = InputTensorCPU_(0); - auto* Y = OutputTensorCPU_(0); - ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(1)); + auto sizes = ConvPoolOpBase::GetOutputSize(X, X.dim32(1)); + auto* Y = OutputTensorCPU_(0, sizes, at::dtype()); T* Ydata = GetQuantizedOutputData_(); @@ -544,9 +544,9 @@ class MaxPoolDnnLowPOp final : public ConvPoolDNNLowPOpBase { const T* Xdata = QuantizeInputIfNeeded(this, 0, in_qparams_[0], X_temp); auto& X = InputTensorCPU_(0); - auto* Y = OutputTensorCPU_(0); int channels = X.dim32(X.ndim() - 1); - ConvPoolOpBase::SetOutputSize(X, Y, channels); + auto sizes = ConvPoolOpBase::GetOutputSize(X, channels); + auto* Y = OutputTensorCPU_(0, sizes, at::dtype()); T* Ydata = GetQuantizedOutputData_(); diff --git a/caffe2/share/contrib/depthwise/depthwise3x3_conv_op.cc b/caffe2/share/contrib/depthwise/depthwise3x3_conv_op.cc index 37460c84a11..b7bab4f41ab 100644 --- a/caffe2/share/contrib/depthwise/depthwise3x3_conv_op.cc +++ b/caffe2/share/contrib/depthwise/depthwise3x3_conv_op.cc @@ -442,7 +442,6 @@ class Depthwise3x3ConvOp final : public ConvPoolOpBase { bool RunOnDeviceWithOrderNCHW() override { const Tensor& X = Input(0); auto& filter = Input(1); - Tensor* Y = Output(0); const int N = X.dim32(0), C = X.dim32(1); CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim()); const int M = filter.dim32(0); @@ -452,8 +451,8 @@ class Depthwise3x3ConvOp final : public ConvPoolOpBase { CAFFE_ENFORCE_EQ(C, this->group_); CAFFE_ENFORCE_EQ(M, this->group_); - ConvPoolOpBase::SetOutputSize(X, Y, filter.dim32(0)); - Y->mutable_data(); + auto sizes = ConvPoolOpBase::GetOutputSize(X, filter.dim32(0)); + Tensor* Y = Output(0, sizes, at::dtype()); DepthwiseArgs args; args.batch = X.dim32(0);