mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Prevent ctc_loss op from segfaulting when given empty batch.
PiperOrigin-RevId: 163663460
This commit is contained in:
parent
e17650b698
commit
f19bb3bebf
|
|
@ -88,6 +88,9 @@ class CTCLossOp : public OpKernel {
|
||||||
labels_indices->shape().DebugString(), " vs. ",
|
labels_indices->shape().DebugString(), " vs. ",
|
||||||
labels_values->shape().DebugString()));
|
labels_values->shape().DebugString()));
|
||||||
|
|
||||||
|
OP_REQUIRES(ctx, batch_size != 0,
|
||||||
|
errors::InvalidArgument("batch_size must not be 0"));
|
||||||
|
|
||||||
TensorShape labels_shape({batch_size, max_time});
|
TensorShape labels_shape({batch_size, max_time});
|
||||||
std::vector<int64> order{0, 1};
|
std::vector<int64> order{0, 1};
|
||||||
sparse::SparseTensor labels_sp(*labels_indices, *labels_values,
|
sparse::SparseTensor labels_sp(*labels_indices, *labels_values,
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.ops import ctc_ops
|
from tensorflow.python.ops import ctc_ops
|
||||||
from tensorflow.python.ops import gradients_impl
|
from tensorflow.python.ops import gradients_impl
|
||||||
|
|
@ -260,6 +261,18 @@ class CTCLossTest(test.TestCase):
|
||||||
"explicitly disabled"):
|
"explicitly disabled"):
|
||||||
_ = gradients_impl._hessian_vector_product(loss, [inputs_t], v)
|
_ = gradients_impl._hessian_vector_product(loss, [inputs_t], v)
|
||||||
|
|
||||||
|
def testEmptyBatch(self):
|
||||||
|
inputs = constant_op.constant([], dtype=dtypes.float32, shape=(1, 0, 2))
|
||||||
|
sequence_lengths = constant_op.constant([], dtype=dtypes.int32)
|
||||||
|
labels = sparse_tensor.SparseTensor(
|
||||||
|
indices=constant_op.constant([], shape=(0, 2), dtype=dtypes.int64),
|
||||||
|
values=constant_op.constant([], shape=(0,), dtype=dtypes.int32),
|
||||||
|
dense_shape=[5, 5])
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=False) as sess:
|
||||||
|
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
|
||||||
|
"batch_size must not be 0"):
|
||||||
|
sess.run(ctc_ops.ctc_loss(labels, inputs, sequence_lengths))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user