Support int64 for mul; this will be needed for i64 tensor_scatter_updates.
This also enables support for squared_difference for i64.

PiperOrigin-RevId: 381160512
Change-Id: I42a4cb34aa510d2d886afd75c5636bd421c3485b
Commit 94ee5b9c8f (parent a3b7535ec5)
@@ -81,6 +81,7 @@
 * `tf.lite`:
   * The recommended Android NDK version for building TensorFlow Lite has
     been changed from r18b to r19c.
+  * Supports int64 for mul.
 * `tf.saved_model`:
   * SavedModels can now save custom gradients. Use the option
     `tf.saved_model.SaveOption(experimental_custom_gradients=True)` to

@@ -248,10 +248,10 @@ bool VerifyMulOpShapeConstraints(MulOp op) {
         /*max_bcast_rank=*/4);
   }

-  // Allows I32, QI16 and F32 outputs when the operands have valid shapes, which
-  // are broadcastable shapes up to four dimension or have same shapes.
-  if (IsI32Type(element_type) || IsQI16Type(element_type) ||
-      element_type.isF32()) {
+  // Allows I32, I64, QI16 and F32 outputs when the operands have valid shapes,
+  // which are broadcastable shapes up to four dimension or have same shapes.
+  if (IsI32Type(element_type) || IsI64Type(element_type) ||
+      IsQI16Type(element_type) || element_type.isF32()) {
     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
         /*max_bcast_rank=*/4);

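The new IsI64Type predicate is referenced but not defined in this hunk. A plausible sketch of such a helper, mirroring the existing IsI32Type/IsQI16Type checks (an assumption, not the definition from this commit):

#include "mlir/IR/Types.h"

// Assumed shape of the predicate used above; the real definition lives
// elsewhere in the same file. True for a signless 64-bit integer type.
bool IsI64Type(mlir::Type type) { return type.isSignlessInteger(64); }
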
@@ -2338,11 +2338,11 @@ def TFL_MulOp : TFL_Op<"mul", [
   }];

   let arguments = (
-    ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
-        TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
+    ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$lhs,
+        TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$rhs,
         TFL_AFAttr:$fused_activation_function);

-  let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$output);

   let hasFolder = 1;

@@ -2149,3 +2149,12 @@ func @conv3d_transpose_unsupported_strides(%arg0: tensor<2x5x6x8x2xf32>, %arg1:
   // CHECK-LABEL: conv3d_transpose_unsupported_strides
   // CHECK: "tf.Conv3DBackpropInputV2"
 }
+
+func @mul_i64(%arg0: tensor<14xi64>, %arg1: tensor<14xi64>) -> tensor<14xi64> {
+  %0 = "tf.Mul"(%arg0, %arg1) : (tensor<14xi64>, tensor<14xi64>) -> tensor<14xi64>
+  return %0: tensor<14xi64>
+
+  // CHECK-LABEL: mul_i64
+  // CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<14xi64>
+  // CHECK: return
+}

@@ -89,7 +89,6 @@ MLIR_CONVERTER_KNOWN_BUGS = {
     # int64.
     r"div.*dtype=tf\.int64": "119126484",
     r"floor_div.*dtype=tf\.int64": "119126484",
-    r"mul.*dtype=tf\.int64": "119126484",
     r"relu.*dtype=tf\.int64": "119126484",
     r"squared_difference.*dtype=tf\.int64": "119126484",
     # Post-training quantization support missing for below op in mlir.

@@ -74,9 +74,7 @@ const std::map<string, string>& GetKnownBrokenTests() {
       {R"(^\/where.*1,2,3,1)", "134692786"},

      {R"(^\/div.*dtype=tf\.int64)", "119126484"},
-      {R"(^\/mul.*dtype=tf\.int64)", "119126484"},
      {R"(^\/floor_div.*dtype=tf\.int64)", "119126484"},
-      {R"(^\/squared_difference.*dtype=tf\.int64)", "119126484"},
  });
  return *kBrokenTests;
}

@@ -29,7 +29,7 @@ def make_tensor_scatter_update_tests(options):
   """Make a set of tests to do tensor_scatter_update."""

   test_parameters = [{
-      "input_dtype": [tf.float32, tf.int32],
+      "input_dtype": [tf.float32, tf.int32, tf.int64],
       "input_shape": [[14], [2, 4, 7]],
       "updates_count": [1, 3, 5],
   }]

@@ -157,6 +157,12 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
         TF_LITE_MUL(optimized_ops, Mul, float);
       }
     }
+  } else if (output->type == kTfLiteInt64) {
+    if (need_broadcast) {
+      TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int64_t);
+    } else {
+      TF_LITE_MUL(reference_ops, Mul, int64_t);
+    }
   }
 #undef TF_LITE_MUL
 }

@@ -275,7 +281,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_OK(context,
                     GetOutputSafe(context, node, kOutputTensor, &output));

-  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32 ||
+      output->type == kTfLiteInt64) {
     EvalMul<kernel_type>(context, node, params, data, input1, input2, output);
   } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
              output->type == kTfLiteInt16) {

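For int64 the kernel falls back to the reference implementations (reference_ops::Mul and reference_ops::BroadcastMul4DSlow). The standalone sketch below only illustrates what the non-broadcast case computes for same-shaped int64 operands, an element-wise product clamped to the fused-activation range; it is not the TFLite code and the names are illustrative:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: element-wise int64 multiply with activation clamping,
// approximating the non-broadcast reference Mul for same-shaped operands.
std::vector<int64_t> MulInt64(const std::vector<int64_t>& lhs,
                              const std::vector<int64_t>& rhs,
                              int64_t activation_min, int64_t activation_max) {
  std::vector<int64_t> out(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    out[i] =
        std::min(std::max(lhs[i] * rhs[i], activation_min), activation_max);
  }
  return out;
}
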
@@ -100,7 +100,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* min_version = */ 1,
              /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_MUL, Register_MUL(), /* min_version = */ 1,
-             /* max_version = */ 4);
+             /* max_version = */ 5);
   AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION(),
              /* min_version = */ 1,
              /* max_version = */ 2);

@@ -265,7 +265,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
              /* min_version = */ 1,
              /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_MUL, Register_MUL_REF(), /* min_version = */ 1,
-             /* max_version = */ 4);
+             /* max_version = */ 5);
   AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2NORM_REF(),
              /* min_version = */ 1,
              /* max_version = */ 2);

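The max_version bump to 5 in both resolvers is what lets an interpreter accept models whose MUL op was written out as version 5. As a usage sketch, an application building its own selective resolver would register the same range; the header paths and the Register_MUL symbol below reflect my understanding of the TFLite tree and should be treated as assumptions, not code from this commit:

#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Sketch: a resolver that only registers MUL, accepting versions 1 through 5
// so that int64 models (version 5) resolve.
tflite::MutableOpResolver MakeMulOnlyResolver() {
  tflite::MutableOpResolver resolver;
  resolver.AddBuiltin(tflite::BuiltinOperator_MUL,
                      tflite::ops::builtin::Register_MUL(),
                      /*min_version=*/1, /*max_version=*/5);
  return resolver;
}
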
@@ -101,7 +101,6 @@ const std::map<string, string>& GetKnownBrokenTests() {
      {R"(^\/floor_mod.*activation=True.*dtype=tf\.int64)", "112968789"},

      {R"(^\/div.*dtype=tf\.int64)", "119126484"},
-      {R"(^\/mul.*dtype=tf\.int64)", "119126484"},
      {R"(^\/floor_div.*dtype=tf\.int64)", "119126484"},
      {R"(^\/squared_difference.*dtype=tf\.int64)", "119126484"},
  });

@@ -228,6 +228,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       }

     case BuiltinOperator_MUL:
+      // Version 5 supports int64 inputs
+      if (op_sig.inputs.at(0).type == kTfLiteInt64) {
+        return 5;
+      }
       // Version 4 supports int16 inputs
       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
         return 4;

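Order matters here: the int64 check has to precede the int16 check so the highest applicable version is reported. Restated as a standalone helper for illustration only (the real GetBuiltinOperatorVersion also decides the quantized and broadcast cases):

#include "tensorflow/lite/c/common.h"  // TfLiteType, kTfLiteInt64, kTfLiteInt16

// Illustrative helper, not the real implementation: pick the MUL op version
// from the first input's type, newest type-gated version first.
int MulVersionForInputType(TfLiteType input_type) {
  if (input_type == kTfLiteInt64) return 5;  // int64 inputs (this change)
  if (input_type == kTfLiteInt16) return 4;  // int16 inputs
  return 1;  // other cases are decided by the real versioning logic
}
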
@@ -432,6 +432,13 @@ TEST(OpVersionTest, VersioningSubTest) {
   SimpleVersioningTest(BuiltinOperator_SUB);
 }

+TEST(OpVersionTest, VersioningMUL5Test) {
+  OpSignature fake_op_sig;
+  fake_op_sig.op = BuiltinOperator_MUL;
+  fake_op_sig.inputs = CreateOpSignatureTensorSpecs(kTfLiteInt64);
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
+}
+
 TEST(OpVersionTest, VersioningSub4Test) {
   OpSignature fake_op_sig = {
       .op = BuiltinOperator_SUB,

@@ -146,6 +146,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_MUL, 2}, "1.14.0"},
           {{BuiltinOperator_MUL, 3}, "1.15.0"},
           {{BuiltinOperator_MUL, 4}, "2.3.0"},
+          {{BuiltinOperator_MUL, 5}, "2.6.0"},
           {{BuiltinOperator_NON_MAX_SUPPRESSION_V4, 1}, "2.1.0"},
           {{BuiltinOperator_NON_MAX_SUPPRESSION_V5, 1}, "2.1.0"},
           {{BuiltinOperator_PAD, 1}, "1.5.0"},