Merge pull request #49989 from tensorflow/mm-fix-2.4

Fix r2.4 branch after cherrypicks
This commit is contained in:
Mihai Maruseac 2021-06-02 14:51:53 -07:00 committed by GitHub
commit 3ee7cc1593
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/overflow.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
namespace tensorflow {
@ -64,17 +65,20 @@ class SparseSplitOp : public OpKernel {
"), got ", num_split_));
// Prevent overflow by constructing the dense shape separately
TensorShape dense_shape;
int64 total_elements = 1;
const auto input_shape_flat = input_shape.flat<int64>();
for (int i = 0; i < input_shape.NumElements(); i++) {
OP_REQUIRES_OK(context,
dense_shape.AddDimWithStatus(input_shape_flat(i)));
total_elements =
MultiplyWithoutOverflow(total_elements, input_shape_flat(i));
OP_REQUIRES(context, total_elements >= 0,
errors::Internal("Encountered overflow in dense shape"));
}
sparse::SparseTensor sparse_tensor;
OP_REQUIRES_OK(context,
sparse::SparseTensor::Create(input_indices, input_values,
dense_shape, &sparse_tensor));
sparse::SparseTensor::Create(
input_indices, input_values,
TensorShape(input_shape.vec<int64>()), &sparse_tensor));
std::vector<sparse::SparseTensor> outputs;
OP_REQUIRES_OK(context, sparse::SparseTensor::Split<T>(