mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Support TFLite UnsortedSegmentSum Op (converter changes)
PiperOrigin-RevId: 462472445
This commit is contained in:
parent
de62f4f829
commit
59f1c50bf2
|
|
@ -40,6 +40,7 @@
|
|||
* tf.einsum is supported with multiple unknown shapes.
|
||||
* tf.unsortedsegmentprod op is supported.
|
||||
* tf.unsortedsegmentmax op is supported.
|
||||
* tf.unsortedsegmentsum op is supported.
|
||||
* Updates to existing operations:
|
||||
* tfl.scatter_nd now supports I1 for update arg.
|
||||
* Upgrade Flatbuffers v2.0.5 from v1.12.0
|
||||
|
|
|
|||
|
|
@ -4869,6 +4869,31 @@ def TFL_UnsortedSegmentMaxOp: TFL_Op<"unsorted_segment_max", [
|
|||
let results = (outs TFL_TensorOf<[F32, I32]>:$output);
|
||||
}
|
||||
|
||||
def TFL_UnsortedSegmentSumOp: TFL_Op<"unsorted_segment_sum", [
|
||||
NoSideEffect, PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
|
||||
|
||||
let summary = "UnsortedSegmentSum operator";
|
||||
|
||||
let description = [{
|
||||
From a tensor segmentation, computes the `output` resulting from
|
||||
summing together elements mapped to the same segment_id. I.e. `output[i]` is
|
||||
equal to the tensor sum of all elements from the input tensor mapped to
|
||||
segment_id `i`. If no tensors are mapped to a particular included
|
||||
segment_id, the output at that indice will be a zero tensor with the
|
||||
appropriate shape. Note the values of segment_ids are always validated to be
|
||||
less than num_segments and an error is thrown for out-of-bound indices
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I32]>:$input,
|
||||
TFL_I32Tensor:$segment_ids,
|
||||
TFL_I32Tensor:$num_segments
|
||||
);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I32]>:$output);
|
||||
}
|
||||
|
||||
def TFL_YieldOp : Op<TFL_Dialect, "yield",
|
||||
[NoSideEffect,
|
||||
Terminator,
|
||||
|
|
|
|||
|
|
@ -2082,6 +2082,41 @@ func.func @unsorted_segment_max_i64(%arg0: tensor<9xf32>, %arg1: tensor<9xi64>)
|
|||
|
||||
// -----
|
||||
|
||||
func.func @unsorted_segment_sum(%arg0: tensor<8xf32>, %arg1: tensor<8xi32>) -> tensor<8xf32> {
|
||||
%num_segments = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
|
||||
%0 = "tf.UnsortedSegmentSum"(%arg0, %arg1, %num_segments) : (tensor<8xf32>, tensor<8xi32>, tensor<i32>) -> tensor<8xf32>
|
||||
func.return %0 : tensor<8xf32>
|
||||
// CHECK-LABEL: unsorted_segment_sum
|
||||
// CHECK: %[[CST:.*]] = arith.constant dense<8> : tensor<i32>
|
||||
// CHECK: %[[BCT:.*]] = "tfl.unsorted_segment_sum"(%arg0, %arg1, %[[CST]]) : (tensor<8xf32>, tensor<8xi32>, tensor<i32>) -> tensor<8xf32>
|
||||
// CHECK: return %[[BCT]] : tensor<8xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @unsorted_segment_sum_3arg(%arg0: tensor<5xi32>, %arg1: tensor<5xi32>, %arg2: tensor<i64>) -> tensor<5xi32>{
|
||||
%0 = "tf.UnsortedSegmentSum"(%arg0, %arg1, %arg2) : (tensor<5xi32>, tensor<5xi32>, tensor<i64>) -> tensor<5xi32>
|
||||
func.return %0 : tensor<5xi32>
|
||||
// CHECK-LABEL: unsorted_segment_sum_3arg
|
||||
// CHECK: %[[BCT:.*]] = "tfl.cast"(%arg2) : (tensor<i64>) -> tensor<i32>
|
||||
// CHECK: %[[RES:.*]] = "tfl.unsorted_segment_sum"(%arg0, %arg1, %[[BCT]]) : (tensor<5xi32>, tensor<5xi32>, tensor<i32>) -> tensor<5xi32>
|
||||
// CHECK: return %[[RES]] : tensor<5xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @unsorted_segment_sum_i64(%arg0: tensor<9xf32>, %arg1: tensor<9xi64>) -> tensor<9xf32> {
|
||||
%num_segments = "tf.Const"() {value = dense<9> : tensor<i32>} : () -> tensor<i32>
|
||||
%0 = "tf.UnsortedSegmentSum"(%arg0, %arg1, %num_segments) : (tensor<9xf32>, tensor<9xi64>, tensor<i32>) -> tensor<9xf32>
|
||||
func.return %0 : tensor<9xf32>
|
||||
// CHECK-LABEL: unsorted_segment_sum_i64
|
||||
// CHECK: %[[CST:.*]] = arith.constant dense<9> : tensor<i32>
|
||||
// CHECK: %[[CAST:.*]] = "tfl.cast"(%arg1) : (tensor<9xi64>) -> tensor<9xi32>
|
||||
// CHECK: %[[RES:.*]] = "tfl.unsorted_segment_sum"(%arg0, %[[CAST]], %[[CST]]) : (tensor<9xf32>, tensor<9xi32>, tensor<i32>) -> tensor<9xf32>
|
||||
// CHECK: return %[[RES]] : tensor<9xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @rfft2d(%arg0: tensor<10x20x10x30xf32>, %arg1: tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>> {
|
||||
%0 = "tf.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf32>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>>
|
||||
|
|
|
|||
|
|
@ -2988,3 +2988,30 @@ func.func @scatter_nd_i1(%arg0: tensor<?xi32>, %arg1: tensor<?xi1>, %arg2: tenso
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testUnsortedSegmentSum
|
||||
func.func @testUnsortedSegmentSum(%arg0: tensor<8xf32>, %arg1: tensor<8xi32>, %arg2: tensor<i32>) -> tensor<8xf32> {
|
||||
// CHECK: "tfl.unsorted_segment_sum"(%arg0, %arg1, %arg2)
|
||||
%0 = "tfl.unsorted_segment_sum"(%arg0, %arg1, %arg2) : (tensor<8xf32>, tensor<8xi32>, tensor<i32>) -> tensor<8xf32>
|
||||
func.return %0 : tensor<8xf32>
|
||||
// CHECK: return %0 : tensor<8xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testUnsortedSegmentProd
|
||||
func.func @testUnsortedSegmentProd(%arg0: tensor<8xf32>, %arg1: tensor<8xi32>, %arg2: tensor<i32>) -> tensor<8xf32> {
|
||||
// CHECK: "tfl.unsorted_segment_prod"(%arg0, %arg1, %arg2)
|
||||
%0 = "tfl.unsorted_segment_prod"(%arg0, %arg1, %arg2) : (tensor<8xf32>, tensor<8xi32>, tensor<i32>) -> tensor<8xf32>
|
||||
func.return %0 : tensor<8xf32>
|
||||
// CHECK: return %0 : tensor<8xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testUnsortedSegmentMax
|
||||
func.func @testUnsortedSegmentMax(%arg0: tensor<8xf32>, %arg1: tensor<8xi32>, %arg2: tensor<i32>) -> tensor<8xf32> {
|
||||
// CHECK: "tfl.unsorted_segment_max"(%arg0, %arg1, %arg2)
|
||||
%0 = "tfl.unsorted_segment_max"(%arg0, %arg1, %arg2) : (tensor<8xf32>, tensor<8xi32>, tensor<i32>) -> tensor<8xf32>
|
||||
func.return %0 : tensor<8xf32>
|
||||
// CHECK: return %0 : tensor<8xf32>
|
||||
}
|
||||
|
|
|
|||
|
|
@ -220,14 +220,15 @@ def LegalizeSqrt : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
|
|||
def LegalizeSquare : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
|
||||
def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, $segment_ids),
|
||||
(TFL_SegmentSumOp $data, (CreateTFCastToInt32Op $segment_ids))>;
|
||||
def LegalizeUnsortedSegmentProd :
|
||||
Pat<(TF_UnsortedSegmentProdOp $data, $segment_ids, $num_segments),
|
||||
(TFL_UnsortedSegmentProdOp $data, (CreateTFCastToInt32Op $segment_ids),
|
||||
(CreateTFCastToInt32Op $num_segments))>;
|
||||
def LegalizeUnsortedSegmentMax :
|
||||
Pat<(TF_UnsortedSegmentMaxOp $data, $segment_ids, $num_segments),
|
||||
(TFL_UnsortedSegmentMaxOp $data, (CreateTFCastToInt32Op $segment_ids),
|
||||
(CreateTFCastToInt32Op $num_segments))>;
|
||||
foreach UnsortedSegmentOp = [
|
||||
[TF_UnsortedSegmentSumOp, TFL_UnsortedSegmentSumOp],
|
||||
[TF_UnsortedSegmentMaxOp, TFL_UnsortedSegmentMaxOp],
|
||||
[TF_UnsortedSegmentProdOp, TFL_UnsortedSegmentProdOp]] in {
|
||||
def Legalize#UnsortedSegmentOp[0] :
|
||||
Pat<(UnsortedSegmentOp[0] $data, $segment_ids, $num_segments),
|
||||
(UnsortedSegmentOp[1] $data, (CreateTFCastToInt32Op $segment_ids),
|
||||
(CreateTFCastToInt32Op $num_segments))>;
|
||||
}
|
||||
def LegalizeSelect : Pat<(TF_SelectOp $cond, $x, $y),
|
||||
(TFL_SelectOp $cond, $x, $y)>;
|
||||
def LegalizeSelectV2SameStaticShape : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y),
|
||||
|
|
|
|||
|
|
@ -182,6 +182,7 @@ def generated_test_models():
|
|||
"unroll_batch_matmul",
|
||||
"unsorted_segment_max",
|
||||
"unsorted_segment_prod",
|
||||
"unsorted_segment_sum",
|
||||
"where",
|
||||
"where_v2",
|
||||
"while",
|
||||
|
|
@ -271,6 +272,7 @@ def generated_test_models_failing(conversion_mode, delegate):
|
|||
"unique",
|
||||
"unsorted_segment_max",
|
||||
"unsorted_segment_prod",
|
||||
"unsorted_segment_sum",
|
||||
"where",
|
||||
"where_v2",
|
||||
"while",
|
||||
|
|
|
|||
|
|
@ -186,8 +186,7 @@ from tensorflow.lite.testing.op_tests.unfused_gru import make_unfused_gru_tests
|
|||
from tensorflow.lite.testing.op_tests.unique import make_unique_tests
|
||||
from tensorflow.lite.testing.op_tests.unpack import make_unpack_tests
|
||||
from tensorflow.lite.testing.op_tests.unroll_batch_matmul import make_unroll_batch_matmul_tests
|
||||
from tensorflow.lite.testing.op_tests.unsorted_segment_max import make_unsorted_segment_max_tests
|
||||
from tensorflow.lite.testing.op_tests.unsorted_segment_prod import make_unsorted_segment_prod_tests
|
||||
from tensorflow.lite.testing.op_tests.unsorted_segment import make_unsorted_segment_max_tests, make_unsorted_segment_prod_tests, make_unsorted_segment_sum_tests
|
||||
from tensorflow.lite.testing.op_tests.where import make_where_tests
|
||||
from tensorflow.lite.testing.op_tests.where_v2 import make_where_v2_tests
|
||||
from tensorflow.lite.testing.op_tests.while_loop import make_while_tests
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test configs for unsorted_segment_prod."""
|
||||
"""Test configs for unsorted_segment ops."""
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
|
||||
|
|
@ -20,29 +20,52 @@ from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
|
|||
from tensorflow.lite.testing.zip_test_utils import register_make_test_function
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_unsorted_segment_prod_tests(options):
|
||||
"""Make a set of tests for unsorted_segment_prod op."""
|
||||
def make_unsorted_segment_tests(options, unsorted_segment_op):
|
||||
"""Make a set of tests for given unsorted_segment op."""
|
||||
test_parameters = [{
|
||||
"data": [[5]],
|
||||
"data_shape": [[5]],
|
||||
"segment_id": [[0, 1, 1, 0, 1]],
|
||||
"num_segments": [2],
|
||||
"dtype": [tf.int32, tf.float32],
|
||||
"multi_node": [0]
|
||||
}, {
|
||||
"data": [[2, 3, 4], [2, 5, 2]],
|
||||
"data_shape": [[2, 3, 4], [2, 5, 2]],
|
||||
"segment_id": [[0, 1]],
|
||||
"num_segments": [2],
|
||||
"dtype": [tf.int32, tf.float32],
|
||||
"multi_node": [0]
|
||||
}, {
|
||||
"data": [[4]],
|
||||
"data_shape": [[4]],
|
||||
"segment_id": [[0, 0, 1, 8]],
|
||||
"num_segments": [9],
|
||||
"dtype": [tf.int32, tf.float32],
|
||||
"multi_node": [0]
|
||||
}, {
|
||||
"data": [[4]],
|
||||
"data_shape": [[3]],
|
||||
"segment_id": [[-1, -2, -1]],
|
||||
"num_segments": [1],
|
||||
"dtype": [tf.int32, tf.float32],
|
||||
"multi_node": [0]
|
||||
}, {
|
||||
"data_shape": [[3]],
|
||||
"segment_id": [[-1, 0, 1]],
|
||||
"num_segments": [2],
|
||||
"dtype": [tf.int32, tf.float32],
|
||||
"multi_node": [0]
|
||||
}, {
|
||||
"data_shape": [[3, 2]],
|
||||
"segment_id": [[-1, 0, 0]],
|
||||
"num_segments": [1],
|
||||
"dtype": [tf.int32, tf.float32],
|
||||
"multi_node": [0]
|
||||
}, {
|
||||
"data_shape": [[3, 2]],
|
||||
"segment_id": [[-1, -2, -1]],
|
||||
"num_segments": [1],
|
||||
"dtype": [tf.int32, tf.float32],
|
||||
"multi_node": [0]
|
||||
}, {
|
||||
"data_shape": [[4]],
|
||||
"segment_id_shape": [[4]],
|
||||
"segment_id_min": [0],
|
||||
"segment_id_max": [1],
|
||||
|
|
@ -55,7 +78,7 @@ def make_unsorted_segment_prod_tests(options):
|
|||
|
||||
def build_graph_one_node(parameters):
|
||||
data_tensor = tf.compat.v1.placeholder(
|
||||
dtype=parameters["dtype"], name="data", shape=parameters["data"])
|
||||
dtype=parameters["dtype"], name="data", shape=parameters["data_shape"])
|
||||
segment_ids_tensor = tf.constant(
|
||||
parameters["segment_id"], dtype=tf.int32, name="segment_ids")
|
||||
num_segments_tensor = tf.constant(
|
||||
|
|
@ -63,15 +86,15 @@ def make_unsorted_segment_prod_tests(options):
|
|||
dtype=tf.int32,
|
||||
shape=[],
|
||||
name="num_segments")
|
||||
output = tf.math.unsorted_segment_prod(data_tensor, segment_ids_tensor,
|
||||
num_segments_tensor)
|
||||
output = unsorted_segment_op(data_tensor, segment_ids_tensor,
|
||||
num_segments_tensor)
|
||||
return [data_tensor], [output]
|
||||
|
||||
|
||||
# test cases for handling dynamically shaped input tensor
|
||||
def build_graph_multi_node(parameters):
|
||||
data_tensor = tf.compat.v1.placeholder(
|
||||
dtype=parameters["dtype"], name="data", shape=parameters["data"])
|
||||
dtype=parameters["dtype"], name="data", shape=parameters["data_shape"])
|
||||
segment_ids_tensor = tf.compat.v1.placeholder(
|
||||
dtype=tf.int32,
|
||||
name="segment_ids",
|
||||
|
|
@ -81,9 +104,8 @@ def make_unsorted_segment_prod_tests(options):
|
|||
dtype=tf.int32,
|
||||
shape=[],
|
||||
name="num_segments")
|
||||
intermediate_tensor = tf.math.unsorted_segment_prod(data_tensor,
|
||||
segment_ids_tensor,
|
||||
num_segments_tensor)
|
||||
intermediate_tensor = unsorted_segment_op(data_tensor, segment_ids_tensor,
|
||||
num_segments_tensor)
|
||||
segment_ids_tensor_2 = tf.constant(
|
||||
parameters["segment_id_2"], dtype=tf.int32, name="segment_ids_2")
|
||||
num_segments_tensor_2 = tf.constant(
|
||||
|
|
@ -91,9 +113,8 @@ def make_unsorted_segment_prod_tests(options):
|
|||
dtype=tf.int32,
|
||||
shape=[],
|
||||
name="num_segments_2")
|
||||
output = tf.math.unsorted_segment_prod(intermediate_tensor,
|
||||
segment_ids_tensor_2,
|
||||
num_segments_tensor_2)
|
||||
output = unsorted_segment_op(intermediate_tensor, segment_ids_tensor_2,
|
||||
num_segments_tensor_2)
|
||||
return [data_tensor, segment_ids_tensor], [output]
|
||||
|
||||
def build_graph(parameters):
|
||||
|
|
@ -105,13 +126,13 @@ def make_unsorted_segment_prod_tests(options):
|
|||
|
||||
def build_inputs_one_node(parameters, sess, inputs, outputs):
|
||||
data_value = create_tensor_data(
|
||||
parameters["dtype"], shape=parameters["data"])
|
||||
parameters["dtype"], shape=parameters["data_shape"])
|
||||
return [data_value], sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, [data_value])))
|
||||
|
||||
def build_inputs_multi_node(parameters, sess, inputs, outputs):
|
||||
data_value = create_tensor_data(
|
||||
dtype=parameters["dtype"], shape=parameters["data"])
|
||||
dtype=parameters["dtype"], shape=parameters["data_shape"])
|
||||
segment_id_value = create_tensor_data(
|
||||
dtype=tf.int32,
|
||||
shape=parameters["segment_id_shape"],
|
||||
|
|
@ -128,3 +149,18 @@ def make_unsorted_segment_prod_tests(options):
|
|||
return build_inputs_one_node(parameters, sess, inputs, outputs)
|
||||
|
||||
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_unsorted_segment_prod_tests(options):
|
||||
make_unsorted_segment_tests(options, tf.math.unsorted_segment_prod)
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_unsorted_segment_max_tests(options):
|
||||
make_unsorted_segment_tests(options, tf.math.unsorted_segment_max)
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_unsorted_segment_sum_tests(options):
|
||||
make_unsorted_segment_tests(options, tf.math.unsorted_segment_sum)
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test configs for unsorted_segment_max."""
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
|
||||
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
|
||||
from tensorflow.lite.testing.zip_test_utils import register_make_test_function
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_unsorted_segment_max_tests(options):
|
||||
"""Make a set of tests for unsorted_segment_max op."""
|
||||
test_parameters = [{
|
||||
"data": [[5]],
|
||||
"segment_id": [[0, 1, 1, 0, 1]],
|
||||
"num_segments": [2],
|
||||
"dtype": [tf.int32, tf.float32],
|
||||
"segment_dtype": [tf.int32, tf.int64]
|
||||
}, {
|
||||
"data": [[2, 3, 4], [2, 5, 2]],
|
||||
"segment_id": [[0, 1], [-1, -1]],
|
||||
"num_segments": [2],
|
||||
"dtype": [tf.int32, tf.float32],
|
||||
"segment_dtype": [tf.int32, tf.int64]
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
data_tensor = tf.compat.v1.placeholder(
|
||||
dtype=parameters["dtype"], name="data", shape=parameters["data"])
|
||||
segment_ids_tensor = tf.constant(
|
||||
parameters["segment_id"],
|
||||
dtype=parameters["segment_dtype"],
|
||||
name="segment_ids")
|
||||
num_segments = tf.constant(
|
||||
parameters["num_segments"],
|
||||
dtype=parameters["segment_dtype"],
|
||||
shape=[],
|
||||
name="num_segments")
|
||||
output = tf.math.unsorted_segment_max(data_tensor, segment_ids_tensor,
|
||||
num_segments)
|
||||
return [data_tensor], [output]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
data_value = create_tensor_data(
|
||||
parameters["dtype"], shape=parameters["data"])
|
||||
return [data_value], sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, [data_value])))
|
||||
|
||||
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
|
||||
Loading…
Reference in New Issue
Block a user