mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Fix r2.2 branch after cherrypicks
This commit is contained in:
parent
e522a1924b
commit
e9c91879c5
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/util/overflow.h"
|
||||||
#include "tensorflow/core/util/sparse/sparse_tensor.h"
|
#include "tensorflow/core/util/sparse/sparse_tensor.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -64,17 +65,20 @@ class SparseSplitOp : public OpKernel {
|
||||||
num_split_));
|
num_split_));
|
||||||
|
|
||||||
// Prevent overflow by constructing the dense shape separately
|
// Prevent overflow by constructing the dense shape separately
|
||||||
TensorShape dense_shape;
|
int64 total_elements = 1;
|
||||||
const auto input_shape_flat = input_shape.flat<int64>();
|
const auto input_shape_flat = input_shape.flat<int64>();
|
||||||
for (int i = 0; i < input_shape.NumElements(); i++) {
|
for (int i = 0; i < input_shape.NumElements(); i++) {
|
||||||
OP_REQUIRES_OK(context,
|
total_elements =
|
||||||
dense_shape.AddDimWithStatus(input_shape_flat(i)));
|
MultiplyWithoutOverflow(total_elements, input_shape_flat(i));
|
||||||
|
OP_REQUIRES(context, total_elements >= 0,
|
||||||
|
errors::Internal("Encountered overflow in dense shape"));
|
||||||
}
|
}
|
||||||
|
|
||||||
sparse::SparseTensor sparse_tensor;
|
sparse::SparseTensor sparse_tensor;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
sparse::SparseTensor::Create(input_indices, input_values,
|
sparse::SparseTensor::Create(
|
||||||
dense_shape, &sparse_tensor));
|
input_indices, input_values,
|
||||||
|
TensorShape(input_shape.vec<int64>()), &sparse_tensor));
|
||||||
|
|
||||||
std::vector<sparse::SparseTensor> outputs;
|
std::vector<sparse::SparseTensor> outputs;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user