From d0c360576bd1492dc16004f28f8e066b5d82aaa3 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Wed, 2 Jun 2021 13:43:21 -0700 Subject: [PATCH] Fix r2.2 branch after cherrypicks --- tensorflow/core/kernels/sparse_split_op.cc | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/kernels/sparse_split_op.cc b/tensorflow/core/kernels/sparse_split_op.cc index 3b88a8ca2bf..ece6e832e86 100644 --- a/tensorflow/core/kernels/sparse_split_op.cc +++ b/tensorflow/core/kernels/sparse_split_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/util/overflow.h" #include "tensorflow/core/util/sparse/sparse_tensor.h" namespace tensorflow { @@ -64,17 +65,20 @@ class SparseSplitOp : public OpKernel { "), got ", num_split_)); // Prevent overflow by constructing the dense shape separately - TensorShape dense_shape; + int64 total_elements = 1; const auto input_shape_flat = input_shape.flat(); for (int i = 0; i < input_shape.NumElements(); i++) { - OP_REQUIRES_OK(context, - dense_shape.AddDimWithStatus(input_shape_flat(i))); + total_elements = + MultiplyWithoutOverflow(total_elements, input_shape_flat(i)); + OP_REQUIRES(context, total_elements >= 0, + errors::Internal("Encountered overflow in dense shape")); } sparse::SparseTensor sparse_tensor; OP_REQUIRES_OK(context, - sparse::SparseTensor::Create(input_indices, input_values, - dense_shape, &sparse_tensor)); + sparse::SparseTensor::Create( + input_indices, input_values, + TensorShape(input_shape.vec()), &sparse_tensor)); std::vector outputs; OP_REQUIRES_OK(context, sparse::SparseTensor::Split(