mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Remove "numel > 0" check in Caffe2 conv_transpose_unpool_op_base
D15766739 took care of the main ConvTranspose op, but this removes the check and will allow Int8ConvTranspose to handle numel == 0.
This commit is contained in:
parent
a9bb68d436
commit
81d45eed50
|
|
@ -360,7 +360,7 @@ bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
|
|||
}
|
||||
math::Set<T, Context>(filter.numel(), T(0), dfilter_data, &context_);
|
||||
|
||||
if (X.numel() == 0) {
|
||||
if (X.numel() == 0 && dbias_data != nullptr) {
|
||||
VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
|
||||
math::Set<T, Context>(C, T(0), dbias_data, &context_);
|
||||
return true;
|
||||
|
|
@ -523,7 +523,7 @@ bool ConvTransposeGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
|
|||
}
|
||||
math::Set<T, Context>(filter.numel(), T(0), dfilter_data, &context_);
|
||||
|
||||
if (X.numel() == 0) {
|
||||
if (X.numel() && dbias_data != nullptr) {
|
||||
VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
|
||||
math::Set<T, Context>(C, T(0), dbias_data, &context_);
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -555,6 +555,9 @@ bool ConvTransposeMobileOp<T, Context>::RunOnDeviceWithOrderNCHW() {
|
|||
|
||||
auto sizes = ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
|
||||
Tensor* Y = Output(0, sizes, at::dtype<T>());
|
||||
if (N == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const int outputH = Y->dim32(2);
|
||||
const int outputW = Y->dim32(3);
|
||||
|
|
|
|||
|
|
@ -136,7 +136,6 @@ class ConvTransposeUnpoolBase : public Operator<Context> {
|
|||
// Gets the output size. The output channel is manually specified.
|
||||
std::vector<int64_t> GetOutputSize(const Tensor& input, int output_channel) {
|
||||
CAFFE_ENFORCE(4 == input.dim());
|
||||
CAFFE_ENFORCE(input.numel() > 0);
|
||||
int N = input.dim32(0);
|
||||
bool channel_first = false; // initialized to suppress compiler warning.
|
||||
int H = 0, W = 0; // initialized to suppress compiler warning.
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
|
|||
size=st.integers(7, 10),
|
||||
input_channels=st.integers(1, 8),
|
||||
output_channels=st.integers(1, 8),
|
||||
batch_size=st.integers(1, 3),
|
||||
batch_size=st.integers(0, 3),
|
||||
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
|
||||
shared_buffer=st.booleans(),
|
||||
use_bias=st.booleans(),
|
||||
|
|
@ -90,7 +90,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
|
|||
size=st.integers(7, 10),
|
||||
input_channels=st.integers(1, 8),
|
||||
output_channels=st.integers(1, 8),
|
||||
batch_size=st.integers(1, 3),
|
||||
batch_size=st.integers(0, 3),
|
||||
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
|
||||
shared_buffer=st.booleans(),
|
||||
use_bias=st.booleans(),
|
||||
|
|
@ -166,7 +166,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
|
|||
size=st.integers(7, 10),
|
||||
input_channels=st.integers(1, 8),
|
||||
output_channels=st.integers(1, 8),
|
||||
batch_size=st.integers(1, 3),
|
||||
batch_size=st.integers(0, 3),
|
||||
engine=st.sampled_from(["", "BLOCK"]),
|
||||
use_bias=st.booleans(),
|
||||
**hu.gcs)
|
||||
|
|
@ -235,7 +235,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
|
|||
size=st.integers(7, 10),
|
||||
input_channels=st.integers(1, 8),
|
||||
output_channels=st.integers(1, 8),
|
||||
batch_size=st.integers(1, 3),
|
||||
batch_size=st.integers(0, 3),
|
||||
order=st.sampled_from(["NCHW", "NHWC"]),
|
||||
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
|
||||
use_bias=st.booleans(),
|
||||
|
|
@ -303,7 +303,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
|
|||
size=st.integers(7, 10),
|
||||
input_channels=st.integers(1, 8),
|
||||
output_channels=st.integers(1, 8),
|
||||
batch_size=st.integers(1, 3),
|
||||
batch_size=st.integers(0, 3),
|
||||
order=st.sampled_from(["NCHW", "NHWC"]),
|
||||
engine=st.sampled_from(["", "BLOCK"]),
|
||||
use_bias=st.booleans(),
|
||||
|
|
@ -368,7 +368,7 @@ class TestConvolutionTranspose(hu.HypothesisTestCase):
|
|||
size=st.integers(7, 10),
|
||||
input_channels=st.integers(1, 8),
|
||||
output_channels=st.integers(1, 8),
|
||||
batch_size=st.integers(1, 4),
|
||||
batch_size=st.integers(0, 4),
|
||||
group=st.integers(1, 4),
|
||||
order=st.sampled_from(["NCHW", "NHWC"]),
|
||||
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user