mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Fix tf.raw_ops.SparseCross failing CHECK.
This commit is contained in:
parent
5cf71e2d85
commit
a6eaf1d55a
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
|
|
@ -295,6 +296,7 @@ class SparseCrossOp : public OpKernel {
|
|||
int64 signed_hash_key_;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
|
||||
hash_key_ = static_cast<uint64>(signed_hash_key_);
|
||||
OP_REQUIRES_OK(context, context->GetAttr("internal_type", &internal_type_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
|
|
@ -308,8 +310,10 @@ class SparseCrossOp : public OpKernel {
|
|||
OP_REQUIRES_OK(context,
|
||||
context->input_list("dense_inputs", &dense_list_in));
|
||||
|
||||
OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in,
|
||||
shapes_list_in, dense_list_in));
|
||||
DataType internal_type = internal_type_;
|
||||
OP_REQUIRES_OK(
|
||||
context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
|
||||
dense_list_in, internal_type));
|
||||
|
||||
std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
|
||||
GenerateColumnsFromInput(indices_list_in, values_list_in,
|
||||
|
|
@ -353,10 +357,19 @@ class SparseCrossOp : public OpKernel {
|
|||
Status ValidateInput(const OpInputList& indices_list_in,
|
||||
const OpInputList& values_list_in,
|
||||
const OpInputList& shapes_list_in,
|
||||
const OpInputList& dense_list_in) {
|
||||
const OpInputList& dense_list_in,
|
||||
const DataType& internal_type) {
|
||||
const auto size = indices_list_in.size();
|
||||
// Only perform internal_type check for SparseCrossOp.
|
||||
// Check if the internal_type is not invalid before doing so.
|
||||
bool check_type = internal_type != DT_INVALID;
|
||||
// Validates indices_list_in OpInputList.
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (check_type && indices_list_in[i].dtype() != DT_INT64) {
|
||||
return errors::InvalidArgument("Input indices should be of type ",
|
||||
DT_INT64, " but received ",
|
||||
indices_list_in[i].dtype());
|
||||
}
|
||||
if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) {
|
||||
return errors::InvalidArgument(
|
||||
"Input indices should be a matrix but received shape ",
|
||||
|
|
@ -375,6 +388,14 @@ class SparseCrossOp : public OpKernel {
|
|||
values_list_in.size());
|
||||
}
|
||||
for (int i = 0; i < size; i++) {
|
||||
// Make sure to avoid the expected type to be string, but input values to be
|
||||
// int64.
|
||||
if (check_type && internal_type == DT_STRING &&
|
||||
values_list_in[i].dtype() == DT_INT64) {
|
||||
return errors::InvalidArgument("Input values should be of internal type ",
|
||||
internal_type, " but received ",
|
||||
values_list_in[i].dtype());
|
||||
}
|
||||
if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) {
|
||||
return errors::InvalidArgument(
|
||||
"Input values should be a vector but received shape ",
|
||||
|
|
@ -395,6 +416,11 @@ class SparseCrossOp : public OpKernel {
|
|||
shapes_list_in.size());
|
||||
}
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (check_type && shapes_list_in[i].dtype() != DT_INT64) {
|
||||
return errors::InvalidArgument("Input shape should be of type ", DT_INT64,
|
||||
" but received ",
|
||||
shapes_list_in[i].dtype());
|
||||
}
|
||||
if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) {
|
||||
return errors::InvalidArgument(
|
||||
"Input shapes should be a vector but received shape ",
|
||||
|
|
@ -410,6 +436,14 @@ class SparseCrossOp : public OpKernel {
|
|||
|
||||
// Validates dense_list_in OpInputList
|
||||
for (int i = 0; i < dense_list_in.size(); ++i) {
|
||||
// Make sure to avoid the expected type to be string, but input values to be
|
||||
// int64.
|
||||
if (check_type && internal_type == DT_STRING &&
|
||||
dense_list_in[i].dtype() == DT_INT64) {
|
||||
return errors::InvalidArgument("Dense inputs should be of internal type ",
|
||||
internal_type, " but received ",
|
||||
dense_list_in[i].dtype());
|
||||
}
|
||||
if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) {
|
||||
return errors::InvalidArgument(
|
||||
"Dense inputs should be a matrix but received shape ",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user