Support TFLite UnsortedSegmentSum Op (converter changes)

PiperOrigin-RevId: 462472445
This commit is contained in:
Luke Boyer 2022-07-21 14:34:21 -07:00 committed by TensorFlower Gardener
parent de62f4f829
commit 59f1c50bf2
9 changed files with 156 additions and 92 deletions

View File

@ -40,6 +40,7 @@
* tf.einsum is supported with multiple unknown shapes.
* tf.unsortedsegmentprod op is supported.
* tf.unsortedsegmentmax op is supported.
* tf.unsortedsegmentsum op is supported.
* Updates to existing operations:
* tfl.scatter_nd now supports I1 for update arg.
* Upgrade Flatbuffers v2.0.5 from v1.12.0

View File

@ -4869,6 +4869,31 @@ def TFL_UnsortedSegmentMaxOp: TFL_Op<"unsorted_segment_max", [
let results = (outs TFL_TensorOf<[F32, I32]>:$output);
}
def TFL_UnsortedSegmentSumOp: TFL_Op<"unsorted_segment_sum", [
NoSideEffect, PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
let summary = "UnsortedSegmentSum operator";
let description = [{
From a tensor segmentation, computes the `output` resulting from
summing together elements mapped to the same segment_id. I.e. `output[i]` is
equal to the tensor sum of all elements from the input tensor mapped to
segment_id `i`. If no tensors are mapped to a particular included
segment_id, the output at that indice will be a zero tensor with the
appropriate shape. Note the values of segment_ids are always validated to be
less than num_segments and an error is thrown for out-of-bound indices
}];
let arguments = (ins
TFL_TensorOf<[F32, I32]>:$input,
TFL_I32Tensor:$segment_ids,
TFL_I32Tensor:$num_segments
);
let results = (outs TFL_TensorOf<[F32, I32]>:$output);
}
def TFL_YieldOp : Op<TFL_Dialect, "yield",
[NoSideEffect,
Terminator,

View File

@ -2082,6 +2082,41 @@ func.func @unsorted_segment_max_i64(%arg0: tensor<9xf32>, %arg1: tensor<9xi64>)
// -----
func.func @unsorted_segment_sum(%arg0: tensor<8xf32>, %arg1: tensor<8xi32>) -> tensor<8xf32> {
%num_segments = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
%0 = "tf.UnsortedSegmentSum"(%arg0, %arg1, %num_segments) : (tensor<8xf32>, tensor<8xi32>, tensor<i32>) -> tensor<8xf32>
func.return %0 : tensor<8xf32>
// CHECK-LABEL: unsorted_segment_sum
// CHECK: %[[CST:.*]] = arith.constant dense<8> : tensor<i32>
// CHECK: %[[BCT:.*]] = "tfl.unsorted_segment_sum"(%arg0, %arg1, %[[CST]]) : (tensor<8xf32>, tensor<8xi32>, tensor<i32>) -> tensor<8xf32>
// CHECK: return %[[BCT]] : tensor<8xf32>
}
// -----
func.func @unsorted_segment_sum_3arg(%arg0: tensor<5xi32>, %arg1: tensor<5xi32>, %arg2: tensor<i64>) -> tensor<5xi32>{
%0 = "tf.UnsortedSegmentSum"(%arg0, %arg1, %arg2) : (tensor<5xi32>, tensor<5xi32>, tensor<i64>) -> tensor<5xi32>
func.return %0 : tensor<5xi32>
// CHECK-LABEL: unsorted_segment_sum_3arg
// CHECK: %[[BCT:.*]] = "tfl.cast"(%arg2) : (tensor<i64>) -> tensor<i32>
// CHECK: %[[RES:.*]] = "tfl.unsorted_segment_sum"(%arg0, %arg1, %[[BCT]]) : (tensor<5xi32>, tensor<5xi32>, tensor<i32>) -> tensor<5xi32>
// CHECK: return %[[RES]] : tensor<5xi32>
}
// -----
func.func @unsorted_segment_sum_i64(%arg0: tensor<9xf32>, %arg1: tensor<9xi64>) -> tensor<9xf32> {
%num_segments = "tf.Const"() {value = dense<9> : tensor<i32>} : () -> tensor<i32>
%0 = "tf.UnsortedSegmentSum"(%arg0, %arg1, %num_segments) : (tensor<9xf32>, tensor<9xi64>, tensor<i32>) -> tensor<9xf32>
func.return %0 : tensor<9xf32>
// CHECK-LABEL: unsorted_segment_sum_i64
// CHECK: %[[CST:.*]] = arith.constant dense<9> : tensor<i32>
// CHECK: %[[CAST:.*]] = "tfl.cast"(%arg1) : (tensor<9xi64>) -> tensor<9xi32>
// CHECK: %[[RES:.*]] = "tfl.unsorted_segment_sum"(%arg0, %[[CAST]], %[[CST]]) : (tensor<9xf32>, tensor<9xi32>, tensor<i32>) -> tensor<9xf32>
// CHECK: return %[[RES]] : tensor<9xf32>
}
// -----
func.func @rfft2d(%arg0: tensor<10x20x10x30xf32>, %arg1: tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>> {
%0 = "tf.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf32>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>>

View File

@ -2988,3 +2988,30 @@ func.func @scatter_nd_i1(%arg0: tensor<?xi32>, %arg1: tensor<?xi1>, %arg2: tenso
// -----
// CHECK-LABEL: testUnsortedSegmentSum
func.func @testUnsortedSegmentSum(%arg0: tensor<8xf32>, %arg1: tensor<8xi32>, %arg2: tensor<i32>) -> tensor<8xf32> {
// CHECK: "tfl.unsorted_segment_sum"(%arg0, %arg1, %arg2)
%0 = "tfl.unsorted_segment_sum"(%arg0, %arg1, %arg2) : (tensor<8xf32>, tensor<8xi32>, tensor<i32>) -> tensor<8xf32>
func.return %0 : tensor<8xf32>
// CHECK: return %0 : tensor<8xf32>
}
// -----
// CHECK-LABEL: testUnsortedSegmentProd
func.func @testUnsortedSegmentProd(%arg0: tensor<8xf32>, %arg1: tensor<8xi32>, %arg2: tensor<i32>) -> tensor<8xf32> {
// CHECK: "tfl.unsorted_segment_prod"(%arg0, %arg1, %arg2)
%0 = "tfl.unsorted_segment_prod"(%arg0, %arg1, %arg2) : (tensor<8xf32>, tensor<8xi32>, tensor<i32>) -> tensor<8xf32>
func.return %0 : tensor<8xf32>
// CHECK: return %0 : tensor<8xf32>
}
// -----
// CHECK-LABEL: testUnsortedSegmentMax
func.func @testUnsortedSegmentMax(%arg0: tensor<8xf32>, %arg1: tensor<8xi32>, %arg2: tensor<i32>) -> tensor<8xf32> {
// CHECK: "tfl.unsorted_segment_max"(%arg0, %arg1, %arg2)
%0 = "tfl.unsorted_segment_max"(%arg0, %arg1, %arg2) : (tensor<8xf32>, tensor<8xi32>, tensor<i32>) -> tensor<8xf32>
func.return %0 : tensor<8xf32>
// CHECK: return %0 : tensor<8xf32>
}

View File

@ -220,14 +220,15 @@ def LegalizeSqrt : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
def LegalizeSquare : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, $segment_ids),
(TFL_SegmentSumOp $data, (CreateTFCastToInt32Op $segment_ids))>;
def LegalizeUnsortedSegmentProd :
Pat<(TF_UnsortedSegmentProdOp $data, $segment_ids, $num_segments),
(TFL_UnsortedSegmentProdOp $data, (CreateTFCastToInt32Op $segment_ids),
(CreateTFCastToInt32Op $num_segments))>;
def LegalizeUnsortedSegmentMax :
Pat<(TF_UnsortedSegmentMaxOp $data, $segment_ids, $num_segments),
(TFL_UnsortedSegmentMaxOp $data, (CreateTFCastToInt32Op $segment_ids),
(CreateTFCastToInt32Op $num_segments))>;
foreach UnsortedSegmentOp = [
[TF_UnsortedSegmentSumOp, TFL_UnsortedSegmentSumOp],
[TF_UnsortedSegmentMaxOp, TFL_UnsortedSegmentMaxOp],
[TF_UnsortedSegmentProdOp, TFL_UnsortedSegmentProdOp]] in {
def Legalize#UnsortedSegmentOp[0] :
Pat<(UnsortedSegmentOp[0] $data, $segment_ids, $num_segments),
(UnsortedSegmentOp[1] $data, (CreateTFCastToInt32Op $segment_ids),
(CreateTFCastToInt32Op $num_segments))>;
}
def LegalizeSelect : Pat<(TF_SelectOp $cond, $x, $y),
(TFL_SelectOp $cond, $x, $y)>;
def LegalizeSelectV2SameStaticShape : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y),

View File

@ -182,6 +182,7 @@ def generated_test_models():
"unroll_batch_matmul",
"unsorted_segment_max",
"unsorted_segment_prod",
"unsorted_segment_sum",
"where",
"where_v2",
"while",
@ -271,6 +272,7 @@ def generated_test_models_failing(conversion_mode, delegate):
"unique",
"unsorted_segment_max",
"unsorted_segment_prod",
"unsorted_segment_sum",
"where",
"where_v2",
"while",

View File

@ -186,8 +186,7 @@ from tensorflow.lite.testing.op_tests.unfused_gru import make_unfused_gru_tests
from tensorflow.lite.testing.op_tests.unique import make_unique_tests
from tensorflow.lite.testing.op_tests.unpack import make_unpack_tests
from tensorflow.lite.testing.op_tests.unroll_batch_matmul import make_unroll_batch_matmul_tests
from tensorflow.lite.testing.op_tests.unsorted_segment_max import make_unsorted_segment_max_tests
from tensorflow.lite.testing.op_tests.unsorted_segment_prod import make_unsorted_segment_prod_tests
from tensorflow.lite.testing.op_tests.unsorted_segment import make_unsorted_segment_max_tests, make_unsorted_segment_prod_tests, make_unsorted_segment_sum_tests
from tensorflow.lite.testing.op_tests.where import make_where_tests
from tensorflow.lite.testing.op_tests.where_v2 import make_where_v2_tests
from tensorflow.lite.testing.op_tests.while_loop import make_while_tests

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test configs for unsorted_segment_prod."""
"""Test configs for unsorted_segment ops."""
import tensorflow.compat.v1 as tf
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
@ -20,29 +20,52 @@ from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
from tensorflow.lite.testing.zip_test_utils import register_make_test_function
@register_make_test_function()
def make_unsorted_segment_prod_tests(options):
"""Make a set of tests for unsorted_segment_prod op."""
def make_unsorted_segment_tests(options, unsorted_segment_op):
"""Make a set of tests for given unsorted_segment op."""
test_parameters = [{
"data": [[5]],
"data_shape": [[5]],
"segment_id": [[0, 1, 1, 0, 1]],
"num_segments": [2],
"dtype": [tf.int32, tf.float32],
"multi_node": [0]
}, {
"data": [[2, 3, 4], [2, 5, 2]],
"data_shape": [[2, 3, 4], [2, 5, 2]],
"segment_id": [[0, 1]],
"num_segments": [2],
"dtype": [tf.int32, tf.float32],
"multi_node": [0]
}, {
"data": [[4]],
"data_shape": [[4]],
"segment_id": [[0, 0, 1, 8]],
"num_segments": [9],
"dtype": [tf.int32, tf.float32],
"multi_node": [0]
}, {
"data": [[4]],
"data_shape": [[3]],
"segment_id": [[-1, -2, -1]],
"num_segments": [1],
"dtype": [tf.int32, tf.float32],
"multi_node": [0]
}, {
"data_shape": [[3]],
"segment_id": [[-1, 0, 1]],
"num_segments": [2],
"dtype": [tf.int32, tf.float32],
"multi_node": [0]
}, {
"data_shape": [[3, 2]],
"segment_id": [[-1, 0, 0]],
"num_segments": [1],
"dtype": [tf.int32, tf.float32],
"multi_node": [0]
}, {
"data_shape": [[3, 2]],
"segment_id": [[-1, -2, -1]],
"num_segments": [1],
"dtype": [tf.int32, tf.float32],
"multi_node": [0]
}, {
"data_shape": [[4]],
"segment_id_shape": [[4]],
"segment_id_min": [0],
"segment_id_max": [1],
@ -55,7 +78,7 @@ def make_unsorted_segment_prod_tests(options):
def build_graph_one_node(parameters):
data_tensor = tf.compat.v1.placeholder(
dtype=parameters["dtype"], name="data", shape=parameters["data"])
dtype=parameters["dtype"], name="data", shape=parameters["data_shape"])
segment_ids_tensor = tf.constant(
parameters["segment_id"], dtype=tf.int32, name="segment_ids")
num_segments_tensor = tf.constant(
@ -63,15 +86,15 @@ def make_unsorted_segment_prod_tests(options):
dtype=tf.int32,
shape=[],
name="num_segments")
output = tf.math.unsorted_segment_prod(data_tensor, segment_ids_tensor,
num_segments_tensor)
output = unsorted_segment_op(data_tensor, segment_ids_tensor,
num_segments_tensor)
return [data_tensor], [output]
# test cases for handling dynamically shaped input tensor
def build_graph_multi_node(parameters):
data_tensor = tf.compat.v1.placeholder(
dtype=parameters["dtype"], name="data", shape=parameters["data"])
dtype=parameters["dtype"], name="data", shape=parameters["data_shape"])
segment_ids_tensor = tf.compat.v1.placeholder(
dtype=tf.int32,
name="segment_ids",
@ -81,9 +104,8 @@ def make_unsorted_segment_prod_tests(options):
dtype=tf.int32,
shape=[],
name="num_segments")
intermediate_tensor = tf.math.unsorted_segment_prod(data_tensor,
segment_ids_tensor,
num_segments_tensor)
intermediate_tensor = unsorted_segment_op(data_tensor, segment_ids_tensor,
num_segments_tensor)
segment_ids_tensor_2 = tf.constant(
parameters["segment_id_2"], dtype=tf.int32, name="segment_ids_2")
num_segments_tensor_2 = tf.constant(
@ -91,9 +113,8 @@ def make_unsorted_segment_prod_tests(options):
dtype=tf.int32,
shape=[],
name="num_segments_2")
output = tf.math.unsorted_segment_prod(intermediate_tensor,
segment_ids_tensor_2,
num_segments_tensor_2)
output = unsorted_segment_op(intermediate_tensor, segment_ids_tensor_2,
num_segments_tensor_2)
return [data_tensor, segment_ids_tensor], [output]
def build_graph(parameters):
@ -105,13 +126,13 @@ def make_unsorted_segment_prod_tests(options):
def build_inputs_one_node(parameters, sess, inputs, outputs):
data_value = create_tensor_data(
parameters["dtype"], shape=parameters["data"])
parameters["dtype"], shape=parameters["data_shape"])
return [data_value], sess.run(
outputs, feed_dict=dict(zip(inputs, [data_value])))
def build_inputs_multi_node(parameters, sess, inputs, outputs):
data_value = create_tensor_data(
dtype=parameters["dtype"], shape=parameters["data"])
dtype=parameters["dtype"], shape=parameters["data_shape"])
segment_id_value = create_tensor_data(
dtype=tf.int32,
shape=parameters["segment_id_shape"],
@ -128,3 +149,18 @@ def make_unsorted_segment_prod_tests(options):
return build_inputs_one_node(parameters, sess, inputs, outputs)
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
@register_make_test_function()
def make_unsorted_segment_prod_tests(options):
make_unsorted_segment_tests(options, tf.math.unsorted_segment_prod)
@register_make_test_function()
def make_unsorted_segment_max_tests(options):
make_unsorted_segment_tests(options, tf.math.unsorted_segment_max)
@register_make_test_function()
def make_unsorted_segment_sum_tests(options):
make_unsorted_segment_tests(options, tf.math.unsorted_segment_sum)

View File

@ -1,62 +0,0 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test configs for unsorted_segment_max."""
import tensorflow.compat.v1 as tf
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
from tensorflow.lite.testing.zip_test_utils import register_make_test_function
@register_make_test_function()
def make_unsorted_segment_max_tests(options):
"""Make a set of tests for unsorted_segment_max op."""
test_parameters = [{
"data": [[5]],
"segment_id": [[0, 1, 1, 0, 1]],
"num_segments": [2],
"dtype": [tf.int32, tf.float32],
"segment_dtype": [tf.int32, tf.int64]
}, {
"data": [[2, 3, 4], [2, 5, 2]],
"segment_id": [[0, 1], [-1, -1]],
"num_segments": [2],
"dtype": [tf.int32, tf.float32],
"segment_dtype": [tf.int32, tf.int64]
}]
def build_graph(parameters):
data_tensor = tf.compat.v1.placeholder(
dtype=parameters["dtype"], name="data", shape=parameters["data"])
segment_ids_tensor = tf.constant(
parameters["segment_id"],
dtype=parameters["segment_dtype"],
name="segment_ids")
num_segments = tf.constant(
parameters["num_segments"],
dtype=parameters["segment_dtype"],
shape=[],
name="num_segments")
output = tf.math.unsorted_segment_max(data_tensor, segment_ids_tensor,
num_segments)
return [data_tensor], [output]
def build_inputs(parameters, sess, inputs, outputs):
data_value = create_tensor_data(
parameters["dtype"], shape=parameters["data"])
return [data_value], sess.run(
outputs, feed_dict=dict(zip(inputs, [data_value])))
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)