diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc index 426382edeca..a1f60019141 100644 --- a/tensorflow/core/kernels/ctc_loss_op.cc +++ b/tensorflow/core/kernels/ctc_loss_op.cc @@ -88,6 +88,9 @@ class CTCLossOp : public OpKernel { labels_indices->shape().DebugString(), " vs. ", labels_values->shape().DebugString())); + OP_REQUIRES(ctx, batch_size != 0, + errors::InvalidArgument("batch_size must not be 0")); + TensorShape labels_shape({batch_size, max_time}); std::vector order{0, 1}; sparse::SparseTensor labels_sp(*labels_indices, *labels_values, diff --git a/tensorflow/python/kernel_tests/ctc_loss_op_test.py b/tensorflow/python/kernel_tests/ctc_loss_op_test.py index 5b93f90a799..18e92162b93 100644 --- a/tensorflow/python/kernel_tests/ctc_loss_op_test.py +++ b/tensorflow/python/kernel_tests/ctc_loss_op_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import ctc_ops from tensorflow.python.ops import gradients_impl @@ -260,6 +261,18 @@ class CTCLossTest(test.TestCase): "explicitly disabled"): _ = gradients_impl._hessian_vector_product(loss, [inputs_t], v) + def testEmptyBatch(self): + inputs = constant_op.constant([], dtype=dtypes.float32, shape=(1, 0, 2)) + sequence_lengths = constant_op.constant([], dtype=dtypes.int32) + labels = sparse_tensor.SparseTensor( + indices=constant_op.constant([], shape=(0, 2), dtype=dtypes.int64), + values=constant_op.constant([], shape=(0,), dtype=dtypes.int32), + dense_shape=[5, 5]) + + with self.test_session(use_gpu=False) as sess: + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "batch_size must not be 0"): + sess.run(ctc_ops.ctc_loss(labels, inputs, sequence_lengths)) if __name__ == "__main__": test.main()