Prevent ctc_loss op from segfaulting when given empty batch.

PiperOrigin-RevId: 163663460
This commit is contained in:
A. Unique TensorFlower 2017-07-31 01:07:41 -07:00 committed by TensorFlower Gardener
parent e17650b698
commit f19bb3bebf
2 changed files with 16 additions and 0 deletions

View File

@ -88,6 +88,9 @@ class CTCLossOp : public OpKernel {
labels_indices->shape().DebugString(), " vs. ",
labels_values->shape().DebugString()));
OP_REQUIRES(ctx, batch_size != 0,
errors::InvalidArgument("batch_size must not be 0"));
TensorShape labels_shape({batch_size, max_time});
std::vector<int64> order{0, 1};
sparse::SparseTensor labels_sp(*labels_indices, *labels_values,

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import ctc_ops
from tensorflow.python.ops import gradients_impl
@ -260,6 +261,18 @@ class CTCLossTest(test.TestCase):
"explicitly disabled"):
_ = gradients_impl._hessian_vector_product(loss, [inputs_t], v)
def testEmptyBatch(self):
inputs = constant_op.constant([], dtype=dtypes.float32, shape=(1, 0, 2))
sequence_lengths = constant_op.constant([], dtype=dtypes.int32)
labels = sparse_tensor.SparseTensor(
indices=constant_op.constant([], shape=(0, 2), dtype=dtypes.int64),
values=constant_op.constant([], shape=(0,), dtype=dtypes.int32),
dense_shape=[5, 5])
with self.test_session(use_gpu=False) as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"batch_size must not be 0"):
sess.run(ctc_ops.ctc_loss(labels, inputs, sequence_lengths))
if __name__ == "__main__":
test.main()