mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Merge pull request #49987 from tensorflow/mm-fix-2.2
Fix r2.2 branch after cherrypicks
This commit is contained in:
commit
ada627478c
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||
#include <vector>
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/util/overflow.h"
|
||||
#include "tensorflow/core/util/sparse/sparse_tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
|
@ -64,17 +65,20 @@ class SparseSplitOp : public OpKernel {
|
|||
num_split_));
|
||||
|
||||
// Prevent overflow by constructing the dense shape separately
|
||||
TensorShape dense_shape;
|
||||
int64 total_elements = 1;
|
||||
const auto input_shape_flat = input_shape.flat<int64>();
|
||||
for (int i = 0; i < input_shape.NumElements(); i++) {
|
||||
OP_REQUIRES_OK(context,
|
||||
dense_shape.AddDimWithStatus(input_shape_flat(i)));
|
||||
total_elements =
|
||||
MultiplyWithoutOverflow(total_elements, input_shape_flat(i));
|
||||
OP_REQUIRES(context, total_elements >= 0,
|
||||
errors::Internal("Encountered overflow in dense shape"));
|
||||
}
|
||||
|
||||
sparse::SparseTensor sparse_tensor;
|
||||
OP_REQUIRES_OK(context,
|
||||
sparse::SparseTensor::Create(input_indices, input_values,
|
||||
dense_shape, &sparse_tensor));
|
||||
sparse::SparseTensor::Create(
|
||||
input_indices, input_values,
|
||||
TensorShape(input_shape.vec<int64>()), &sparse_tensor));
|
||||
|
||||
std::vector<sparse::SparseTensor> outputs;
|
||||
OP_REQUIRES_OK(context,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user