Remove "numel > 0" check in Caffe2 conv_transpose_unpool_op_base

D15766739 took care of the main ConvTranspose op; this change removes the
"numel > 0" check from the shared base class (conv_transpose_unpool_op_base),
which will allow Int8ConvTranspose to handle numel == 0 as well.
This commit is contained in:
David Reiss 2019-08-29 16:55:21 -07:00
parent a9bb68d436
commit 81d45eed50
4 changed files with 11 additions and 9 deletions

View File

@ -360,7 +360,7 @@ bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
}
math::Set<T, Context>(filter.numel(), T(0), dfilter_data, &context_);
if (X.numel() == 0) {
if (X.numel() == 0 && dbias_data != nullptr) {
VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
math::Set<T, Context>(C, T(0), dbias_data, &context_);
return true;
@ -523,7 +523,7 @@ bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
}
math::Set<T, Context>(filter.numel(), T(0), dfilter_data, &context_);
if (X.numel() == 0) {
if (X.numel() == 0 && dbias_data != nullptr) {
VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
math::Set<T, Context>(C, T(0), dbias_data, &context_);
return true;

View File

@ -555,6 +555,9 @@ bool ConvTransposeMobileOp<T, Context>::RunOnDeviceWithOrderNCHW() {
auto sizes = ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
Tensor* Y = Output(0, sizes, at::dtype<T>());
if (N == 0) {
return true;
}
const int outputH = Y->dim32(2);
const int outputW = Y->dim32(3);

View File

@ -136,7 +136,6 @@ class ConvTransposeUnpoolBase : public Operator<Context> {
// Gets the output size. The output channel is manually specified.
std::vector<int64_t> GetOutputSize(const Tensor& input, int output_channel) {
CAFFE_ENFORCE(4 == input.dim());
CAFFE_ENFORCE(input.numel() > 0);
int N = input.dim32(0);
bool channel_first = false; // initialized to suppress compiler warning.
int H = 0, W = 0; // initialized to suppress compiler warning.

View File

@ -20,7 +20,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
batch_size=st.integers(1, 3),
batch_size=st.integers(0, 3),
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
shared_buffer=st.booleans(),
use_bias=st.booleans(),
@ -90,7 +90,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
batch_size=st.integers(1, 3),
batch_size=st.integers(0, 3),
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
shared_buffer=st.booleans(),
use_bias=st.booleans(),
@ -166,7 +166,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
batch_size=st.integers(1, 3),
batch_size=st.integers(0, 3),
engine=st.sampled_from(["", "BLOCK"]),
use_bias=st.booleans(),
**hu.gcs)
@ -235,7 +235,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
batch_size=st.integers(1, 3),
batch_size=st.integers(0, 3),
order=st.sampled_from(["NCHW", "NHWC"]),
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
use_bias=st.booleans(),
@ -303,7 +303,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
batch_size=st.integers(1, 3),
batch_size=st.integers(0, 3),
order=st.sampled_from(["NCHW", "NHWC"]),
engine=st.sampled_from(["", "BLOCK"]),
use_bias=st.booleans(),
@ -368,7 +368,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
batch_size=st.integers(1, 4),
batch_size=st.integers(0, 4),
group=st.integers(1, 4),
order=st.sampled_from(["NCHW", "NHWC"]),
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),