mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Merge pull request #50030 from tensorflow/mm-cp-on-r2.4
Validate that a and b are proper sparse tensors
This commit is contained in:
commit
5cc80545fa
|
|
@ -150,6 +150,7 @@ class SparseSparseBinaryOpShared : public OpKernel {
|
|||
|
||||
const int64 a_nnz = a_indices_t->dim_size(0);
|
||||
const int64 b_nnz = b_indices_t->dim_size(0);
|
||||
|
||||
const auto a_values = a_values_t->vec<T>();
|
||||
const auto b_values = b_values_t->vec<T>();
|
||||
|
||||
|
|
@ -166,6 +167,14 @@ class SparseSparseBinaryOpShared : public OpKernel {
|
|||
"Input shapes should be a vector but received shapes ",
|
||||
a_shape_t->shape().DebugString(), " and ",
|
||||
b_shape_t->shape().DebugString()));
|
||||
const int num_dims = a_indices_t->dim_size(1);
|
||||
OP_REQUIRES(
|
||||
ctx, a_shape_t->NumElements() == num_dims,
|
||||
errors::InvalidArgument("Second dimension of a_indices and length of "
|
||||
"a_shape must match, got ",
|
||||
num_dims, " and ", a_shape_t->NumElements()));
|
||||
OP_REQUIRES(ctx, num_dims > 0,
|
||||
errors::InvalidArgument("Tensors must not be empty"));
|
||||
OP_REQUIRES(ctx, a_shape_t->IsSameSize(*b_shape_t),
|
||||
errors::InvalidArgument(
|
||||
"Operands do not have the same ranks; got shapes: ",
|
||||
|
|
@ -180,12 +189,6 @@ class SparseSparseBinaryOpShared : public OpKernel {
|
|||
" for dimension ", i));
|
||||
}
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, a_indices_t->dim_size(1) == b_indices_t->dim_size(1),
|
||||
errors::InvalidArgument(
|
||||
"Indices' dimensions do not match: got ", a_indices_t->dim_size(1),
|
||||
" and ", b_indices_t->dim_size(1), " for the second dimension."));
|
||||
const int num_dims = a_indices_t->dim_size(1);
|
||||
const auto a_indices_mat = a_indices_t->matrix<int64>();
|
||||
const auto b_indices_mat = b_indices_t->matrix<int64>();
|
||||
std::vector<T> a_augmented_values, b_augmented_values;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user