Fix tf.raw_ops.SparseCross failing CHECK.

PiperOrigin-RevId: 368701671
Change-Id: Id805729dd9ba0bda36e4bb309408129b55fb649d
This commit is contained in:
Amit Patankar 2021-04-15 13:03:19 -07:00 committed by Geeta Chavan
parent 4b2ace809b
commit c664ac88cd

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/fingerprint.h"
@ -460,10 +461,19 @@ int64 CalculateBatchSize(const OpInputList& shapes_list_in,
Status ValidateInput(const OpInputList& indices_list_in, Status ValidateInput(const OpInputList& indices_list_in,
const OpInputList& values_list_in, const OpInputList& values_list_in,
const OpInputList& shapes_list_in, const OpInputList& shapes_list_in,
const OpInputList& dense_list_in) { const OpInputList& dense_list_in,
const DataType& internal_type) {
const auto size = indices_list_in.size(); const auto size = indices_list_in.size();
// Only perform internal_type check for SparseCrossOp.
// Check if the internal_type is not invalid before doing so.
bool check_type = internal_type != DT_INVALID;
// Validates indices_list_in OpInputList. // Validates indices_list_in OpInputList.
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
if (check_type && indices_list_in[i].dtype() != DT_INT64) {
return errors::InvalidArgument("Input indices should be of type ",
DT_INT64, " but received ",
indices_list_in[i].dtype());
}
if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) { if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) {
return errors::InvalidArgument( return errors::InvalidArgument(
"Input indices should be a matrix but received shape ", "Input indices should be a matrix but received shape ",
@ -482,6 +492,14 @@ Status ValidateInput(const OpInputList& indices_list_in,
values_list_in.size()); values_list_in.size());
} }
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
// Make sure to avoid the expected type to be string, but input values to be
// int64.
if (check_type && internal_type == DT_STRING &&
values_list_in[i].dtype() == DT_INT64) {
return errors::InvalidArgument("Input values should be of internal type ",
internal_type, " but received ",
values_list_in[i].dtype());
}
if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) { if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) {
return errors::InvalidArgument( return errors::InvalidArgument(
"Input values should be a vector but received shape ", "Input values should be a vector but received shape ",
@ -502,6 +520,11 @@ Status ValidateInput(const OpInputList& indices_list_in,
shapes_list_in.size()); shapes_list_in.size());
} }
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
if (check_type && shapes_list_in[i].dtype() != DT_INT64) {
return errors::InvalidArgument("Input shape should be of type ", DT_INT64,
" but received ",
shapes_list_in[i].dtype());
}
if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) { if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) {
return errors::InvalidArgument( return errors::InvalidArgument(
"Input shapes should be a vector but received shape ", "Input shapes should be a vector but received shape ",
@ -517,6 +540,14 @@ Status ValidateInput(const OpInputList& indices_list_in,
// Validates dense_list_in OpInputList // Validates dense_list_in OpInputList
for (int i = 0; i < dense_list_in.size(); ++i) { for (int i = 0; i < dense_list_in.size(); ++i) {
// Make sure to avoid the expected type to be string, but input values to be
// int64.
if (check_type && internal_type == DT_STRING &&
dense_list_in[i].dtype() == DT_INT64) {
return errors::InvalidArgument("Dense inputs should be of internal type ",
internal_type, " but received ",
dense_list_in[i].dtype());
}
if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) { if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) {
return errors::InvalidArgument( return errors::InvalidArgument(
"Dense inputs should be a matrix but received shape ", "Dense inputs should be a matrix but received shape ",
@ -698,6 +729,7 @@ class SparseCrossOp : public OpKernel {
int64 signed_hash_key_; int64 signed_hash_key_;
OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_)); OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
hash_key_ = static_cast<uint64>(signed_hash_key_); hash_key_ = static_cast<uint64>(signed_hash_key_);
OP_REQUIRES_OK(context, context->GetAttr("internal_type", &internal_type_));
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
@ -711,8 +743,10 @@ class SparseCrossOp : public OpKernel {
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->input_list("dense_inputs", &dense_list_in)); context->input_list("dense_inputs", &dense_list_in));
OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in, DataType internal_type = internal_type_;
shapes_list_in, dense_list_in)); OP_REQUIRES_OK(
context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
dense_list_in, internal_type));
std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns = std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
GenerateColumnsFromInput<InternalType>(indices_list_in, values_list_in, GenerateColumnsFromInput<InternalType>(indices_list_in, values_list_in,
@ -756,6 +790,7 @@ class SparseCrossOp : public OpKernel {
private: private:
int64 num_buckets_; int64 num_buckets_;
uint64 hash_key_; uint64 hash_key_;
DataType internal_type_;
}; };
class SparseCrossV2Op : public OpKernel { class SparseCrossV2Op : public OpKernel {
@ -773,8 +808,11 @@ class SparseCrossV2Op : public OpKernel {
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->input_list("dense_inputs", &dense_list_in)); context->input_list("dense_inputs", &dense_list_in));
OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in, // Set internal_type to invalid_type so that the check will be ignored.
shapes_list_in, dense_list_in)); DataType internal_type = DT_INVALID;
OP_REQUIRES_OK(
context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
dense_list_in, internal_type));
const Tensor* sep_t; const Tensor* sep_t;
OP_REQUIRES_OK(context, context->input("sep", &sep_t)); OP_REQUIRES_OK(context, context->input("sep", &sep_t));
@ -832,8 +870,11 @@ class SparseCrossHashedOp : public OpKernel {
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->input_list("dense_inputs", &dense_list_in)); context->input_list("dense_inputs", &dense_list_in));
OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in, // Set internal_type to invalid_type so that the check will be ignored.
shapes_list_in, dense_list_in)); DataType internal_type = DT_INVALID;
OP_REQUIRES_OK(
context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
dense_list_in, internal_type));
const Tensor* num_buckets_t; const Tensor* num_buckets_t;
OP_REQUIRES_OK(context, context->input("num_buckets", &num_buckets_t)); OP_REQUIRES_OK(context, context->input("num_buckets", &num_buckets_t));