mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Prevent ctc_loss op from segfaulting when given empty batch.
PiperOrigin-RevId: 163663460
This commit is contained in:
parent
e17650b698
commit
f19bb3bebf
|
|
@ -88,6 +88,9 @@ class CTCLossOp : public OpKernel {
|
|||
labels_indices->shape().DebugString(), " vs. ",
|
||||
labels_values->shape().DebugString()));
|
||||
|
||||
OP_REQUIRES(ctx, batch_size != 0,
|
||||
errors::InvalidArgument("batch_size must not be 0"));
|
||||
|
||||
TensorShape labels_shape({batch_size, max_time});
|
||||
std::vector<int64> order{0, 1};
|
||||
sparse::SparseTensor labels_sp(*labels_indices, *labels_values,
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import numpy as np
|
|||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import ctc_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
|
|
@ -260,6 +261,18 @@ class CTCLossTest(test.TestCase):
|
|||
"explicitly disabled"):
|
||||
_ = gradients_impl._hessian_vector_product(loss, [inputs_t], v)
|
||||
|
||||
def testEmptyBatch(self):
|
||||
inputs = constant_op.constant([], dtype=dtypes.float32, shape=(1, 0, 2))
|
||||
sequence_lengths = constant_op.constant([], dtype=dtypes.int32)
|
||||
labels = sparse_tensor.SparseTensor(
|
||||
indices=constant_op.constant([], shape=(0, 2), dtype=dtypes.int64),
|
||||
values=constant_op.constant([], shape=(0,), dtype=dtypes.int32),
|
||||
dense_shape=[5, 5])
|
||||
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||
"batch_size must not be 0"):
|
||||
sess.run(ctc_ops.ctc_loss(labels, inputs, sequence_lengths))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user