Support int64 for mul; this will be needed for i64 tensor_scatter_updates.

This also enables squared_difference support for i64.

PiperOrigin-RevId: 381160512
Change-Id: I42a4cb34aa510d2d886afd75c5636bd421c3485b
Renjie Liu, 2021-06-23 18:53:31 -07:00 (committed by TensorFlower Gardener)
parent a3b7535ec5
commit 94ee5b9c8f
14 changed files with 40 additions and 15 deletions
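As orientation, the sketch below shows the kind of model this change unlocks: an elementwise multiply on int64 tensors that can now be converted to a TFLite tfl.mul (builtin MUL version 5). It is illustrative only and not part of this commit; it assumes TF 2.6 or newer (per the runtime_version table updated below) and uses the public converter API.

import tensorflow as tf

# Illustrative sketch, not part of this commit: an int64 elementwise
# multiply converted to TFLite. Requires a runtime with MUL version 5.
class MulI64(tf.Module):

  @tf.function(input_signature=[
      tf.TensorSpec([14], tf.int64),
      tf.TensorSpec([14], tf.int64),
  ])
  def __call__(self, lhs, rhs):
    return lhs * rhs  # lowers to tfl.mul with i64 operands

module = MulI64()
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [module.__call__.get_concrete_function()])
tflite_model = converter.convert()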


@@ -81,6 +81,7 @@
 *   `tf.lite`:
     *   The recommended Android NDK version for building TensorFlow Lite has
         been changed from r18b to r19c.
+    *   Supports int64 for mul.
 *   `tf.saved_model`:
     *   SavedModels can now save custom gradients. Use the option
         `tf.saved_model.SaveOption(experimental_custom_gradients=True)` to


@@ -248,10 +248,10 @@ bool VerifyMulOpShapeConstraints(MulOp op) {
                                                        /*max_bcast_rank=*/4);
   }

-  // Allows I32, QI16 and F32 outputs when the operands have valid shapes, which
-  // are broadcastable shapes up to four dimension or have same shapes.
-  if (IsI32Type(element_type) || IsQI16Type(element_type) ||
-      element_type.isF32()) {
+  // Allows I32, I64, QI16 and F32 outputs when the operands have valid shapes,
+  // which are broadcastable shapes up to four dimension or have same shapes.
+  if (IsI32Type(element_type) || IsI64Type(element_type) ||
+      IsQI16Type(element_type) || element_type.isF32()) {
     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
         /*max_bcast_rank=*/4);


@@ -2338,11 +2338,11 @@ def TFL_MulOp : TFL_Op<"mul", [
   }];

   let arguments = (
-    ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
-        TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
+    ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$lhs,
+        TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$rhs,
         TFL_AFAttr:$fused_activation_function);

-  let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$output);

   let hasFolder = 1;


@@ -2149,3 +2149,12 @@ func @conv3d_transpose_unsupported_strides(%arg0: tensor<2x5x6x8x2xf32>, %arg1:
   // CHECK-LABEL: conv3d_transpose_unsupported_strides
   // CHECK: "tf.Conv3DBackpropInputV2"
 }
+
+func @mul_i64(%arg0: tensor<14xi64>, %arg1: tensor<14xi64>) -> tensor<14xi64> {
+  %0 = "tf.Mul"(%arg0, %arg1) : (tensor<14xi64>, tensor<14xi64>) -> tensor<14xi64>
+  return %0: tensor<14xi64>
+
+  // CHECK-LABEL: mul_i64
+  // CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<14xi64>
+  // CHECK: return
+}


@@ -89,7 +89,6 @@ MLIR_CONVERTER_KNOWN_BUGS = {
     # int64.
     r"div.*dtype=tf\.int64": "119126484",
     r"floor_div.*dtype=tf\.int64": "119126484",
-    r"mul.*dtype=tf\.int64": "119126484",
     r"relu.*dtype=tf\.int64": "119126484",
     r"squared_difference.*dtype=tf\.int64": "119126484",
     # Post-training quantization support missing for below op in mlir.


@@ -74,9 +74,7 @@ const std::map<string, string>& GetKnownBrokenTests() {
       {R"(^\/where.*1,2,3,1)", "134692786"},
       {R"(^\/div.*dtype=tf\.int64)", "119126484"},
-      {R"(^\/mul.*dtype=tf\.int64)", "119126484"},
       {R"(^\/floor_div.*dtype=tf\.int64)", "119126484"},
-      {R"(^\/squared_difference.*dtype=tf\.int64)", "119126484"},
   });

   return *kBrokenTests;
 }


@@ -29,7 +29,7 @@ def make_tensor_scatter_update_tests(options):
   """Make a set of tests to do tensor_scatter_update."""

   test_parameters = [{
-      "input_dtype": [tf.float32, tf.int32],
+      "input_dtype": [tf.float32, tf.int32, tf.int64],
       "input_shape": [[14], [2, 4, 7]],
       "updates_count": [1, 3, 5],
   }]
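Since the commit message cites i64 tensor_scatter_updates as the motivation, here is an illustrative standalone example (not from this commit) of the pattern the generated tests above now exercise with int64 inputs, using the public tf.tensor_scatter_nd_update API:

import tensorflow as tf

# Illustrative only: scatter int64 updates into an int64 tensor, the
# dtype combination the test parameters above now generate.
tensor = tf.zeros([14], dtype=tf.int64)
indices = tf.constant([[2], [5], [9]])            # positions to overwrite
updates = tf.constant([7, 8, 9], dtype=tf.int64)  # i64 update values
result = tf.tensor_scatter_nd_update(tensor, indices, updates)
# result: [0 0 7 0 0 8 0 0 0 9 0 0 0 0]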


@@ -157,6 +157,12 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
       TF_LITE_MUL(optimized_ops, Mul, float);
     }
-  }
+  } else if (output->type == kTfLiteInt64) {
+    if (need_broadcast) {
+      TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int64_t);
+    } else {
+      TF_LITE_MUL(reference_ops, Mul, int64_t);
+    }
+  }
 #undef TF_LITE_MUL
 }
@@ -275,7 +281,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context,
                     GetOutputSafe(context, node, kOutputTensor, &output));

-  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32 ||
+      output->type == kTfLiteInt64) {
     EvalMul<kernel_type>(context, node, params, data, input1, input2, output);
   } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
              output->type == kTfLiteInt16) {


@@ -100,7 +100,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* min_version = */ 1,
              /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_MUL, Register_MUL(), /* min_version = */ 1,
-             /* max_version = */ 4);
+             /* max_version = */ 5);
   AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION(),
              /* min_version = */ 1,
              /* max_version = */ 2);


@@ -265,7 +265,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
              /* min_version = */ 1,
              /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_MUL, Register_MUL_REF(), /* min_version = */ 1,
-             /* max_version = */ 4);
+             /* max_version = */ 5);
   AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2NORM_REF(),
              /* min_version = */ 1,
              /* max_version = */ 2);


@@ -101,7 +101,6 @@ const std::map<string, string>& GetKnownBrokenTests() {
       {R"(^\/floor_mod.*activation=True.*dtype=tf\.int64)", "112968789"},
       {R"(^\/div.*dtype=tf\.int64)", "119126484"},
-      {R"(^\/mul.*dtype=tf\.int64)", "119126484"},
       {R"(^\/floor_div.*dtype=tf\.int64)", "119126484"},
       {R"(^\/squared_difference.*dtype=tf\.int64)", "119126484"},
   });


@@ -228,6 +228,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     }
     case BuiltinOperator_MUL:
+      // Version 5 supports int64 inputs
+      if (op_sig.inputs.at(0).type == kTfLiteInt64) {
+        return 5;
+      }
       // Version 4 supports int16 inputs
       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
         return 4;


@@ -432,6 +432,13 @@ TEST(OpVersionTest, VersioningSubTest) {
   SimpleVersioningTest(BuiltinOperator_SUB);
 }

+TEST(OpVersionTest, VersioningMUL5Test) {
+  OpSignature fake_op_sig;
+  fake_op_sig.op = BuiltinOperator_MUL;
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt64);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
+}
+
 TEST(OpVersionTest, VersioningSub4Test) {
   OpSignature fake_op_sig = {
       .op = BuiltinOperator_SUB,


@@ -146,6 +146,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_MUL, 2}, "1.14.0"},
           {{BuiltinOperator_MUL, 3}, "1.15.0"},
           {{BuiltinOperator_MUL, 4}, "2.3.0"},
+          {{BuiltinOperator_MUL, 5}, "2.6.0"},
          {{BuiltinOperator_NON_MAX_SUPPRESSION_V4, 1}, "2.1.0"},
          {{BuiltinOperator_NON_MAX_SUPPRESSION_V5, 1}, "2.1.0"},
          {{BuiltinOperator_PAD, 1}, "1.5.0"},