mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
support bad_indices_policy for TensorScatterOp and ScatterNdUpdateOp
This CL introduces the bad_indices_policy as ScatterNd already had. PiperOrigin-RevId: 654727734
This commit is contained in:
parent
f24daf98fc
commit
14725f749c
|
|
@ -37,6 +37,9 @@
|
|||
* Add support for `stablehlo.composite`.
|
||||
* `EmbeddingLookup` op supports `TensorType_INT4` values.
|
||||
|
||||
* `tf.tensor_scatter_update`, `tf.tensor_scatter_add` and of other reduce types.
|
||||
* Support `bad_indices_policy`.
|
||||
|
||||
## Keras
|
||||
|
||||
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
|
||||
|
|
|
|||
|
|
@ -626,7 +626,7 @@ func.func @scatter_nd_add(%arg0: tensor<7xi64>, %arg1: tensor<1x1xi32>, %arg2: t
|
|||
// CHECK-LABEL: scatter_nd_add
|
||||
// CHECK: %[[GATHER:.*]] = "tf.GatherNd"(%arg0, %arg1) <{bad_indices_policy = ""}> : (tensor<7xi64>, tensor<1x1xi32>) -> tensor<1xi64>
|
||||
// CHECK: %[[ADD:.*]] = "tf.Add"(%arg2, %[[GATHER]]) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
|
||||
// CHECK: %[[SCATTER:.*]] = "tf.TensorScatterUpdate"(%arg0, %arg1, %[[ADD]]) : (tensor<7xi64>, tensor<1x1xi32>, tensor<1xi64>) -> tensor<7xi64>
|
||||
// CHECK: %[[SCATTER:.*]] = "tf.TensorScatterUpdate"(%arg0, %arg1, %[[ADD]]) <{bad_indices_policy = ""}> : (tensor<7xi64>, tensor<1x1xi32>, tensor<1xi64>) -> tensor<7xi64>
|
||||
// CHECK: return %[[SCATTER]] : tensor<7xi64>
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ def DenseElementsAttr : ElementsAttrBase<
|
|||
"non-opaque constant tensor">;
|
||||
|
||||
def CreateGatherNdOp : NativeCodeCall<
|
||||
"$_builder.create<TF::GatherNdOp>($0.getLoc(), $0.getType(), $1, $2)">;
|
||||
"$_builder.create<TF::GatherNdOp>($0.getLoc(), $0.getType(), $1, $2, $3)">;
|
||||
|
||||
def CreateTFCastOpI32 : NativeCodeCall<
|
||||
"CreateTFCastOpI32(&$_builder, $_loc, $0, $1)">;
|
||||
|
|
@ -197,9 +197,14 @@ def ConvertMatrixSetDiagV3ToMatrixSetDiag :
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def LowerTensorScatterAdd: Pat<
|
||||
(TF_TensorScatterAddOp $input, $indices, $updates),
|
||||
(TF_TensorScatterUpdateOp $input, $indices,
|
||||
(TF_AddOp $updates, (CreateGatherNdOp $updates, $input, $indices)))>;
|
||||
(TF_TensorScatterAddOp $input, $indices, $updates, $bad_indices_policy),
|
||||
(TF_TensorScatterUpdateOp
|
||||
$input,
|
||||
$indices,
|
||||
(TF_AddOp
|
||||
$updates,
|
||||
(CreateGatherNdOp $updates, $input, $indices, $bad_indices_policy)),
|
||||
$bad_indices_policy)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AddV2 op patterns.
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ func.func @xla_gather(%arg0: tensor<?x2xf32>, %arg1: tensor<1xi32>, %arg2: tenso
|
|||
// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<1> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
|
||||
// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi64>}> : () -> tensor<1xi64>
|
||||
// CHECK: %[[arg1_i64:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi64>
|
||||
// CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) : (tensor<2xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<2xi64>
|
||||
// CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) <{bad_indices_policy = ""}> : (tensor<2xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<2xi64>
|
||||
// CHECK: %[[arg2_i64:.*]] = "tf.Cast"(%arg2) <{Truncate = false}> : (tensor<2xi32>) -> tensor<2xi64>
|
||||
// CHECK: %[[slice:.*]] = "tf.Slice"(%arg0, %[[tensor_scatter_update]], %[[arg2_i64]]) : (tensor<?x2xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
|
||||
// CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[slice]], %[[cst_1]]) : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32>
|
||||
|
|
@ -51,7 +51,7 @@ func.func @xla_gather_known_output_shape(%arg0: tensor<5xi32>, %arg1: tensor<1xi
|
|||
// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
|
||||
// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi64>}> : () -> tensor<0xi64>
|
||||
// CHECK: %[[arg1_i64:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi64>
|
||||
// CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) : (tensor<1xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<1xi64>
|
||||
// CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) <{bad_indices_policy = ""}> : (tensor<1xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<1xi64>
|
||||
// CHECK: %[[arg2_i64:.*]] = "tf.Cast"(%arg2) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi64>
|
||||
// CHECK: %[[slice:.*]] = "tf.Slice"(%arg0, %[[tensor_scatter_update]], %[[arg2_i64]]) : (tensor<5xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi32>
|
||||
// CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[slice]], %[[cst_1]]) : (tensor<1xi32>, tensor<0xi64>) -> tensor<i32>
|
||||
|
|
|
|||
|
|
@ -13837,7 +13837,8 @@ A tensor of indices into ref.}]>:$indices,
|
|||
Arg<TF_Tensor, [{A Tensor. Must have the same type as ref. A tensor of
|
||||
values to add to ref.}]>:$updates,
|
||||
|
||||
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_locking
|
||||
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_locking,
|
||||
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
|
@ -13894,7 +13895,8 @@ A tensor of indices into ref.}]>:$indices,
|
|||
Arg<TF_Tensor, [{A Tensor. Must have the same type as ref. A tensor of
|
||||
values to add to ref.}]>:$updates,
|
||||
|
||||
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_locking
|
||||
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_locking,
|
||||
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
|
@ -13953,7 +13955,8 @@ A tensor of indices into ref.}]>:$indices,
|
|||
Arg<TF_Tensor, [{A Tensor. Must have the same type as ref. A tensor of updated
|
||||
values to add to ref.}]>:$updates,
|
||||
|
||||
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_locking
|
||||
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_locking,
|
||||
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
|
@ -19719,14 +19722,23 @@ numpy=array([[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]],
|
|||
[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]],
|
||||
[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]], dtype=int32)>
|
||||
|
||||
Note: on CPU, if an out of bound index is found, an error is returned.
|
||||
On GPU, if an out of bound index is found, the index is ignored.
|
||||
|
||||
If `indices` contains any out-of-bound indices, depending on
|
||||
`bad_indices_policy`, the op will either return an error or ignore the
|
||||
out-of-bound indices. `bad_indices_policy` can be one of the following values:
|
||||
1. "" or "DEFAULT": raises on CPU and ignore on GPU. This is because
|
||||
historically on CPU and GPU we handle errors in different ways, and for
|
||||
backward compatibility we keep the default behavior.
|
||||
2. "ERROR": raises error; GPU does not support this value.
|
||||
3. "IGNORE": ignore the bad indices; supported on both CPU and GPU.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_Tensor, [{Tensor to copy/update.}]>:$tensor,
|
||||
Arg<TF_I32OrI64Tensor, [{Index tensor.}]>:$indices,
|
||||
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates
|
||||
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates,
|
||||
|
||||
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
|
|
@ -19763,7 +19775,9 @@ Refer to `tf.tensor_scatter_nd_update` for more details.
|
|||
let arguments = (ins
|
||||
Arg<TF_Tensor, [{Tensor to update.}]>:$tensor,
|
||||
Arg<TF_I32OrI64Tensor, [{Index tensor.}]>:$indices,
|
||||
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates
|
||||
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates,
|
||||
|
||||
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
|
|
@ -19780,7 +19794,9 @@ def TF_TensorScatterMinOp : TF_Op<"TensorScatterMin", [Pure]> {
|
|||
let arguments = (ins
|
||||
Arg<TF_Tensor, [{Tensor to update.}]>:$tensor,
|
||||
Arg<TF_I32OrI64Tensor, [{Index tensor.}]>:$indices,
|
||||
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates
|
||||
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates,
|
||||
|
||||
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
|
|
@ -19864,7 +19880,9 @@ On GPU, if an out of bound index is found, the index is ignored.
|
|||
let arguments = (ins
|
||||
Arg<TF_Tensor, [{Tensor to copy/update.}]>:$tensor,
|
||||
Arg<TF_I32OrI64Tensor, [{Index tensor.}]>:$indices,
|
||||
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates
|
||||
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates,
|
||||
|
||||
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
|
|
@ -19889,8 +19907,6 @@ for the existing tensor cannot be re-used, a copy is made and updated.
|
|||
|
||||
If `indices` contains duplicates, then we pick the last update for the index.
|
||||
|
||||
If an out of bound index is found on CPU, an error is returned.
|
||||
|
||||
**WARNING**: There are some GPU specific semantics for this operation.
|
||||
- If an out of bound index is found, the index is ignored.
|
||||
- The order in which updates are applied is nondeterministic, so the output
|
||||
|
|
@ -19914,6 +19930,15 @@ The overall shape of `updates` is:
|
|||
indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
|
||||
```
|
||||
|
||||
If `indices` contains any out-of-bound indices, depending on
|
||||
`bad_indices_policy`, the op will either return an error or ignore the
|
||||
out-of-bound indices. `bad_indices_policy` can be one of the following values:
|
||||
1. "" or "DEFAULT": raises on CPU and ignore on GPU. This is because
|
||||
historically on CPU and GPU we handle errors in different ways, and for
|
||||
backward compatibility we keep the default behavior.
|
||||
2. "ERROR": raises error; GPU does not support this value.
|
||||
3. "IGNORE": ignore the bad indices; supported on both CPU and GPU.
|
||||
|
||||
For usage examples see the python [tf.tensor_scatter_nd_update](
|
||||
https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function
|
||||
}];
|
||||
|
|
@ -19921,7 +19946,9 @@ https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function
|
|||
let arguments = (ins
|
||||
Arg<TF_Tensor, [{Tensor to copy/update.}]>:$tensor,
|
||||
Arg<TensorOf<[TF_Int16, TF_Int32, TF_Int64, TF_Uint16]>, [{Index tensor.}]>:$indices,
|
||||
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates
|
||||
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates,
|
||||
|
||||
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
|
|
|
|||
|
|
@ -616,7 +616,7 @@ func.func @decompose_resource_scatter_add_op(%indices : tensor<2x?xi32>, %update
|
|||
|
||||
// CHECK-DAG: [[READ:%.+]] = "tf.ReadVariableOp"([[VAR]])
|
||||
// CHECK: [[EXPAND:%.+]] = "tf.ExpandDims"([[INDEX]], [[CST]]) : (tensor<2x?xi32>, tensor<i32>) -> tensor<2x?x1xi32>
|
||||
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterAdd"([[READ]], [[EXPAND]], [[UPDATE]]) : (tensor<*xi32>, tensor<2x?x1xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
|
||||
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterAdd"([[READ]], [[EXPAND]], [[UPDATE]]) <{bad_indices_policy = ""}> : (tensor<*xi32>, tensor<2x?x1xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
|
||||
// CHECK: "tf.AssignVariableOp"([[VAR]], [[TENSOR]])
|
||||
"tf.ResourceScatterAdd"(%resource, %indices, %updates) : (tensor<*x!tf_type.resource<tensor<*xi32>>>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> ()
|
||||
|
||||
|
|
@ -637,7 +637,7 @@ func.func @decompose_resource_scatter_add_op_1d_indices(%indices : tensor<?xi32>
|
|||
|
||||
// CHECK-DAG: [[READ:%.+]] = "tf.ReadVariableOp"([[VAR]])
|
||||
// CHECK: [[EXPAND:%.+]] = "tf.ExpandDims"([[INDEX]], [[CST]]) : (tensor<?xi32>, tensor<i32>) -> tensor<?x1xi32>
|
||||
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterAdd"([[READ]], [[EXPAND]], [[UPDATE]]) : (tensor<*xi32>, tensor<?x1xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
|
||||
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterAdd"([[READ]], [[EXPAND]], [[UPDATE]]) <{bad_indices_policy = ""}> : (tensor<*xi32>, tensor<?x1xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
|
||||
// CHECK: "tf.AssignVariableOp"([[VAR]], [[TENSOR]])
|
||||
"tf.ResourceScatterAdd"(%resource, %indices, %updates) : (tensor<*x!tf_type.resource<tensor<*xi32>>>, tensor<?xi32>, tensor<?x?x?xi32>) -> ()
|
||||
|
||||
|
|
@ -686,7 +686,7 @@ func.func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %upd
|
|||
|
||||
// CHECK-DAG: [[READ:%.+]] = "tf.ReadVariableOp"([[VAR]])
|
||||
// CHECK: [[EXPAND:%.+]] = "tf.ExpandDims"([[INDEX]], [[CST]]) : (tensor<2x?xi32>, tensor<i32>) -> tensor<2x?x1xi32>
|
||||
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterUpdate"([[READ]], [[EXPAND]], [[UPDATE]]) : (tensor<*xi32>, tensor<2x?x1xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
|
||||
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterUpdate"([[READ]], [[EXPAND]], [[UPDATE]]) <{bad_indices_policy = ""}> : (tensor<*xi32>, tensor<2x?x1xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
|
||||
// CHECK: "tf.AssignVariableOp"([[VAR]], [[TENSOR]])
|
||||
"tf.ResourceScatterUpdate"(%resource, %indices, %updates) : (tensor<*x!tf_type.resource<tensor<*xi32>>>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> ()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ func.func @invert_permutation(%arg0: tensor<5xi32>) -> tensor<5xi32> {
|
|||
// CHECK-DAG: %[[cst_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<5xi32>}> : () -> tensor<5xi32>
|
||||
|
||||
// CHECK-DAG: %[[INDICES:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32>
|
||||
// CHECK-DAG: %[[INDICES_1:.*]] = "tf.TensorScatterAdd"(%[[cst_3]], %[[INDICES]], %[[cst_2]]) : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
|
||||
// CHECK-DAG: %[[INDICES_1:.*]] = "tf.TensorScatterAdd"(%[[cst_3]], %[[INDICES]], %[[cst_2]]) <{bad_indices_policy = ""}> : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
|
||||
// CHECK-DAG: %[[INDICES_2:.*]] = "tf.Sub"(%[[cst_1]], %[[INDICES_1]]) : (tensor<i32>, tensor<5xi32>) -> tensor<5xi32>
|
||||
// CHECK-DAG: %[[INDICES_3:.*]] = "tf.Mul"(%[[INDICES_2]], %arg0) : (tensor<5xi32>, tensor<5xi32>) -> tensor<5xi32>
|
||||
// CHECK-DAG: %[[INDICES_4:.*]] = "tf.TensorScatterAdd"(%[[cst_3]], %0, %[[UPDATES]]) : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
|
||||
// CHECK-DAG: %[[INDICES_4:.*]] = "tf.TensorScatterAdd"(%[[cst_3]], %0, %[[UPDATES]]) <{bad_indices_policy = ""}> : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
|
||||
// CHECK-DAG: %[[INDICES_5:.*]] = "tf.AddV2"(%[[INDICES_3]], %[[INDICES_4]]) : (tensor<5xi32>, tensor<5xi32>) -> tensor<5xi32>
|
||||
%0 = "tf.InvertPermutation"(%arg0) : (tensor<5xi32>) -> tensor<5xi32>
|
||||
func.return %0 : tensor<5xi32>
|
||||
|
|
@ -825,7 +825,7 @@ func.func @Inv_i32(%arg0: tensor<*xi32>) -> tensor<*xi32> {
|
|||
// CHECK-LABEL: @ScatterNd
|
||||
func.func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
|
||||
// CHECK: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<8xf32>}> : () -> tensor<8xf32>
|
||||
// CHECK: "tf.TensorScatterAdd"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
|
||||
// CHECK: "tf.TensorScatterAdd"(%[[ZERO]], %arg0, %arg1) <{bad_indices_policy = ""}> : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
|
||||
|
||||
%shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32>
|
||||
|
|
@ -1216,10 +1216,10 @@ func.func @scatter_nd_updates(%arg0: tensor<14xf32>, %arg1: tensor<1x1xi32>, %ar
|
|||
// CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
|
||||
// CHECK-DAG: %[[CST0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
|
||||
// CHECK-DAG: %[[CST1:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<14xf32>}> : () -> tensor<14xf32>
|
||||
// CHECK: %[[SCATTER:.*]] = "tf.TensorScatterAdd"(%cst_1, %arg1, %[[CST0]]) : (tensor<14xf32>, tensor<1x1xi32>, tensor<1xf32>) -> tensor<14xf32>
|
||||
// CHECK: %[[SCATTER:.*]] = "tf.TensorScatterAdd"(%cst_1, %arg1, %[[CST0]]) <{bad_indices_policy = ""}> : (tensor<14xf32>, tensor<1x1xi32>, tensor<1xf32>) -> tensor<14xf32>
|
||||
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[CST]], %[[SCATTER]]) : (tensor<f32>, tensor<14xf32>) -> tensor<14xf32>
|
||||
// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[SUB]], %arg0) : (tensor<14xf32>, tensor<14xf32>) -> tensor<14xf32>
|
||||
// CHECK: %[[SCATTER1:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %arg2) : (tensor<14xf32>, tensor<1x1xi32>, tensor<1xf32>) -> tensor<14xf32>
|
||||
// CHECK: %[[SCATTER1:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %arg2) <{bad_indices_policy = ""}> : (tensor<14xf32>, tensor<1x1xi32>, tensor<1xf32>) -> tensor<14xf32>
|
||||
// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[MUL]], %[[SCATTER1]]) : (tensor<14xf32>, tensor<14xf32>) -> tensor<14xf32>
|
||||
// CHECK: return %[[ADD]] : tensor<14xf32>
|
||||
}
|
||||
|
|
@ -1234,10 +1234,10 @@ func.func @scatter_nd_updates_bool(%arg0: tensor<1x24xi1>, %arg1: tensor<1x2x2xi
|
|||
// CHECK-DAG: %[[CST1:.*]] = "tf.Const"() <{value = dense<0> : tensor<1x24xi32>}> : () -> tensor<1x24xi32>
|
||||
// CHECK: %[[CAST0:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1x24xi1>) -> tensor<1x24xi32>
|
||||
// CHECK: %[[CAST1:.*]] = "tf.Cast"(%arg2) <{Truncate = false}> : (tensor<1x2xi1>) -> tensor<1x2xi32>
|
||||
// CHECK: %[[SCATTER:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %[[CST0]]) : (tensor<1x24xi32>, tensor<1x2x2xi32>, tensor<1x2xi32>) -> tensor<1x24xi32>
|
||||
// CHECK: %[[SCATTER:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %[[CST0]]) <{bad_indices_policy = ""}> : (tensor<1x24xi32>, tensor<1x2x2xi32>, tensor<1x2xi32>) -> tensor<1x24xi32>
|
||||
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[CST]], %[[SCATTER]]) : (tensor<i32>, tensor<1x24xi32>) -> tensor<1x24xi32>
|
||||
// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[SUB]], %[[CAST0]]) : (tensor<1x24xi32>, tensor<1x24xi32>) -> tensor<1x24xi32>
|
||||
// CHECK: %[[SCATTER1:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %[[CAST1]]) : (tensor<1x24xi32>, tensor<1x2x2xi32>, tensor<1x2xi32>) -> tensor<1x24xi32>
|
||||
// CHECK: %[[SCATTER1:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %[[CAST1]]) <{bad_indices_policy = ""}> : (tensor<1x24xi32>, tensor<1x2x2xi32>, tensor<1x2xi32>) -> tensor<1x24xi32>
|
||||
// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[MUL]], %[[SCATTER1]]) : (tensor<1x24xi32>, tensor<1x24xi32>) -> tensor<1x24xi32>
|
||||
// CHECK: %[[CAST2:.*]] = "tf.Cast"(%[[ADD]]) <{Truncate = false}> : (tensor<1x24xi32>) -> tensor<1x24xi1>
|
||||
// CHECK: return %[[CAST2]] : tensor<1x24xi1>
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ func.func @main(%arg0: tensor<10x8x9xf32>, %arg1: tensor<5xi32>, %arg2: tensor<5
|
|||
%tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor<!tf_type.variant<tensor<8x9xf32>>>
|
||||
// CHECK: %[[IND_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
|
||||
// CHECK: %[[IND_RESHPE:.*]] = "tf.Reshape"(%[[ARG1]], %[[IND_SHAPE]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32>
|
||||
// CHECK: %[[SC:.*]] = "tf.TensorScatterUpdate"(%[[BUFFER]], %[[IND_RESHPE]], %[[ARG2]]) : (tensor<10x8x9xf32>, tensor<5x1xi32>, tensor<5x8x9xf32>) -> tensor<10x8x9xf32>
|
||||
// CHECK: %[[SC:.*]] = "tf.TensorScatterUpdate"(%[[BUFFER]], %[[IND_RESHPE]], %[[ARG2]]) <{bad_indices_policy = ""}> : (tensor<10x8x9xf32>, tensor<5x1xi32>, tensor<5x8x9xf32>) -> tensor<10x8x9xf32>
|
||||
%scatter = "tf.TensorListScatterIntoExistingList"(%tl, %arg2, %arg1) : (tensor<!tf_type.variant<tensor<8x9xf32>>>, tensor<5x8x9xf32>, tensor<5xi32>) -> tensor<!tf_type.variant<tensor<8x9xf32>>>
|
||||
%stack = "tf.TensorListStack"(%scatter, %elem_shape) : (tensor<!tf_type.variant<tensor<8x9xf32>>>, tensor<2xi32>) -> tensor<10x8x9xf32>
|
||||
// CHECK: return %[[SC]] : tensor<10x8x9xf32>
|
||||
|
|
|
|||
|
|
@ -38,6 +38,14 @@ def CheckHasResourceSubtype : Constraint<CPred<"HasResourceSubtype($0)">>;
|
|||
|
||||
def CreateConstBoolAttrFalse : NativeCodeCall<"$_builder.getBoolAttr(false)">;
|
||||
|
||||
def CreateTensorScatterAddOp : NativeCodeCall<
|
||||
"$_builder.create<TF::TensorScatterAddOp>("
|
||||
"$0.getLoc(), $0.getType(), $0, $1, $2, $_builder.getStringAttr(\"\"))">;
|
||||
|
||||
def CreateTensorScatterUpdateOp : NativeCodeCall<
|
||||
"$_builder.create<TF::TensorScatterUpdateOp>("
|
||||
"$0.getLoc(), $0.getType(), $0, $1, $2, $_builder.getStringAttr(\"\"))">;
|
||||
|
||||
def CreateTFReadVariableOpFromResourceHandle : NativeCodeCall<
|
||||
"$_builder.create<TF::ReadVariableOp>("
|
||||
"$0.getLoc(), GetResourceSubtype($1), $1)">;
|
||||
|
|
@ -354,7 +362,7 @@ def DecomposeResourceScatterAdd : Pat<
|
|||
(TF_ResourceScatterAddOp:$src_op $resource, $indices, $updates),
|
||||
(TF_AssignVariableOp
|
||||
$resource,
|
||||
(TF_TensorScatterAddOp
|
||||
(CreateTensorScatterAddOp
|
||||
(CreateTFReadVariableOp $src_op, $updates, $resource),
|
||||
(TF_ExpandDimsOp $indices,
|
||||
(TF_ConstOp (GetScalarOfType<-1> $indices))),
|
||||
|
|
@ -369,7 +377,7 @@ def DecomposeResourceScatterUpdate : Pat<
|
|||
(TF_ResourceScatterUpdateOp:$src_op $resource, $indices, $updates),
|
||||
(TF_AssignVariableOp
|
||||
$resource,
|
||||
(TF_TensorScatterUpdateOp
|
||||
(CreateTensorScatterUpdateOp
|
||||
(CreateTFReadVariableOp $src_op, $updates, $resource),
|
||||
(TF_ExpandDimsOp $indices,
|
||||
(TF_ConstOp (GetScalarOfType<-1> $indices))),
|
||||
|
|
|
|||
|
|
@ -49,7 +49,12 @@ def CreateTFCastOpI32 : NativeCodeCall<
|
|||
"CreateTFCastOpI32(&$_builder, $0.getLoc(), $1, $2)">;
|
||||
|
||||
def CreateTensorScatterNdOp : NativeCodeCall<
|
||||
"$_builder.create<TF::ScatterNdOp>($0.getLoc(), $0.getType(), $1, $2, $3)">;
|
||||
"$_builder.create<TF::ScatterNdOp>("
|
||||
"$0.getLoc(), $0.getType(), $1, $2, $3, $4)">;
|
||||
|
||||
def CreateTensorScatterUpdateOp : NativeCodeCall<
|
||||
"$_builder.create<TF::TensorScatterUpdateOp>("
|
||||
"$0.getLoc(), $0.getType(), $0, $1, $2, $3)">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Add op patterns.
|
||||
|
|
@ -462,25 +467,13 @@ def LowerOnesLikeOp : LowerInitializationOp<TF_OnesLikeOp, 1>;
|
|||
|
||||
def LowerScatterNdOp :
|
||||
Pat<(TF_ScatterNdOp $indices,
|
||||
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$updates, $shape, ConstantStrAttr<StrAttr, "">),
|
||||
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:
|
||||
$updates,
|
||||
$shape,
|
||||
$bad_indices_policy),
|
||||
(TF_TensorScatterAddOp
|
||||
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
|
||||
$indices, $updates)>;
|
||||
|
||||
def LowerScatterNdOpDefaultBadIndicesPolicy :
|
||||
Pat<(TF_ScatterNdOp $indices,
|
||||
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$updates, $shape, ConstantStrAttr<StrAttr, "DEFAULT">),
|
||||
(TF_TensorScatterAddOp
|
||||
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
|
||||
$indices, $updates)>;
|
||||
|
||||
|
||||
def LowerScatterNdOpIgnoreBadIndices :
|
||||
Pat<(TF_ScatterNdOp $indices,
|
||||
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$updates, $shape, ConstantStrAttr<StrAttr, "IGNORE">),
|
||||
(TF_TensorScatterAddOp
|
||||
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
|
||||
$indices, $updates)>;
|
||||
$indices, $updates, $bad_indices_policy)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Xdivy, Xlog1p and Xlogy op patterns.
|
||||
|
|
@ -526,9 +519,9 @@ def LowerIsFiniteOp : Pat<(TF_IsFiniteOp $x),
|
|||
// TensorScatterUpdate op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def LowerTensorScatterUpdate_1 : Pat<(TF_TensorScatterUpdateOp $input, $indices, $updates),
|
||||
def LowerTensorScatterUpdate_1 : Pat<(TF_TensorScatterUpdateOp$input, $indices, $updates, $bad_indices_policy),
|
||||
(TF_CastOp
|
||||
(TF_TensorScatterUpdateOp
|
||||
(CreateTensorScatterUpdateOp
|
||||
(CreateTFCastOpI32
|
||||
$input,
|
||||
$input,
|
||||
|
|
@ -537,20 +530,29 @@ def LowerTensorScatterUpdate_1 : Pat<(TF_TensorScatterUpdateOp $input, $indices,
|
|||
(CreateTFCastOpI32
|
||||
$updates,
|
||||
$updates,
|
||||
/*truncate=*/ConstBoolAttrFalse)),
|
||||
/*truncate=*/ConstBoolAttrFalse),
|
||||
$bad_indices_policy),
|
||||
/*truncate=*/ConstBoolAttrFalse),
|
||||
[(TensorOf<[TF_Bool]> $input), (TensorOf<[TF_Bool]> $updates)] >;
|
||||
|
||||
def LowerTensorScatterUpdate_2 : Pat<(TF_TensorScatterUpdateOp $input, $indices, $updates),
|
||||
def LowerTensorScatterUpdate_2 : Pat<(TF_TensorScatterUpdateOp $input, $indices, $updates, $bad_indices_policy),
|
||||
(TF_AddV2Op
|
||||
(TF_MulOp
|
||||
(TF_SubOp (TF_ConstOp (GetScalarOfType<1> $updates)),
|
||||
(CreateTensorScatterNdOp $input, $indices,
|
||||
(TF_OnesLikeOp $updates),
|
||||
(CreateTFShapeOp $input, $input, ConstBoolAttrTrue))),
|
||||
(CreateTFShapeOp
|
||||
$input,
|
||||
$input,
|
||||
ConstBoolAttrTrue),
|
||||
$bad_indices_policy)),
|
||||
$input),
|
||||
(CreateTensorScatterNdOp $input, $indices, $updates,
|
||||
(CreateTFShapeOp $input, $input, ConstBoolAttrTrue))),
|
||||
(CreateTFShapeOp
|
||||
$input,
|
||||
$input,
|
||||
ConstBoolAttrTrue),
|
||||
$bad_indices_policy)),
|
||||
[(TensorOf<[TF_Int, TF_Float, TF_Complex]> $updates)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
|||
|
|
@ -83,7 +83,15 @@ numpy=array([[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]],
|
|||
[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]],
|
||||
[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]], dtype=int32)>
|
||||
|
||||
Note: on CPU, if an out of bound index is found, an error is returned.
|
||||
On GPU, if an out of bound index is found, the index is ignored.
|
||||
|
||||
If `indices` contains any out-of-bound indices, depending on
|
||||
`bad_indices_policy`, the op will either return an error or ignore the
|
||||
out-of-bound indices. `bad_indices_policy` can be one of the following values:
|
||||
1. "" or "DEFAULT": raises on CPU and ignore on GPU. This is because
|
||||
historically on CPU and GPU we handle errors in different ways, and for
|
||||
backward compatibility we keep the default behavior.
|
||||
2. "ERROR": raises error; GPU does not support this value.
|
||||
3. "IGNORE": ignore the bad indices; supported on both CPU and GPU.
|
||||
|
||||
END
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,8 +35,6 @@ for the existing tensor cannot be re-used, a copy is made and updated.
|
|||
|
||||
If `indices` contains duplicates, then we pick the last update for the index.
|
||||
|
||||
If an out of bound index is found on CPU, an error is returned.
|
||||
|
||||
**WARNING**: There are some GPU specific semantics for this operation.
|
||||
- If an out of bound index is found, the index is ignored.
|
||||
- The order in which updates are applied is nondeterministic, so the output
|
||||
|
|
@ -60,6 +58,15 @@ The overall shape of `updates` is:
|
|||
indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
|
||||
```
|
||||
|
||||
If `indices` contains any out-of-bound indices, depending on
|
||||
`bad_indices_policy`, the op will either return an error or ignore the
|
||||
out-of-bound indices. `bad_indices_policy` can be one of the following values:
|
||||
1. "" or "DEFAULT": raises on CPU and ignore on GPU. This is because
|
||||
historically on CPU and GPU we handle errors in different ways, and for
|
||||
backward compatibility we keep the default behavior.
|
||||
2. "ERROR": raises error; GPU does not support this value.
|
||||
3. "IGNORE": ignore the bad indices; supported on both CPU and GPU.
|
||||
|
||||
For usage examples see the python [tf.tensor_scatter_nd_update](
|
||||
https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function
|
||||
|
||||
|
|
|
|||
|
|
@ -5569,6 +5569,8 @@ tf_cuda_cc_test(
|
|||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -79,13 +79,10 @@ bool ValidEmptyOutputShape(int64_t num_inputs, int64_t num_indices,
|
|||
return (num_inputs != 0 && num_indices != 0 && num_updates != 0);
|
||||
}
|
||||
|
||||
template <typename Device, typename T, typename Index>
|
||||
class ScatterNdOp : public OpKernel {
|
||||
template <typename Device>
|
||||
class ScatterOpBase : public OpKernel {
|
||||
public:
|
||||
explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||
const DataType dt = DataTypeToEnum<T>::v();
|
||||
const DataType index_t = DataTypeToEnum<Index>::v();
|
||||
OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
|
||||
explicit ScatterOpBase(OpKernelConstruction* c) : OpKernel(c) {
|
||||
std::string bad_indices_policy_str;
|
||||
OP_REQUIRES_OK(c,
|
||||
c->GetAttr(kBadIndicesPolicyAtrr, &bad_indices_policy_str));
|
||||
|
|
@ -101,6 +98,19 @@ class ScatterNdOp : public OpKernel {
|
|||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
BadIndicesPolicy bad_indices_policy_ = BadIndicesPolicy::kDefault;
|
||||
};
|
||||
|
||||
template <typename Device, typename T, typename Index>
|
||||
class ScatterNdOp : public ScatterOpBase<Device> {
|
||||
public:
|
||||
explicit ScatterNdOp(OpKernelConstruction* c) : ScatterOpBase<Device>(c) {
|
||||
const DataType dt = DataTypeToEnum<T>::v();
|
||||
const DataType index_t = DataTypeToEnum<Index>::v();
|
||||
OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* c) override {
|
||||
const Tensor& indices = c->input(0);
|
||||
const Tensor& updates = c->input(1);
|
||||
|
|
@ -163,19 +173,16 @@ class ScatterNdOp : public OpKernel {
|
|||
OP_REQUIRES_OK(
|
||||
c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
|
||||
c, indices, updates, shape, &out, true /*allocate*/,
|
||||
bad_indices_policy_));
|
||||
this->bad_indices_policy_));
|
||||
c->set_output(0, out);
|
||||
}
|
||||
|
||||
private:
|
||||
BadIndicesPolicy bad_indices_policy_ = BadIndicesPolicy::kDefault;
|
||||
};
|
||||
|
||||
template <typename Device, typename T, typename Index,
|
||||
scatter_nd_op::UpdateOp op>
|
||||
class TensorScatterOp : public OpKernel {
|
||||
class TensorScatterOp : public ScatterOpBase<Device> {
|
||||
public:
|
||||
explicit TensorScatterOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||
explicit TensorScatterOp(OpKernelConstruction* c) : ScatterOpBase<Device>(c) {
|
||||
const DataType dt = DataTypeToEnum<T>::v();
|
||||
const DataType index_t = DataTypeToEnum<Index>::v();
|
||||
OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
|
||||
|
|
@ -251,14 +258,14 @@ class TensorScatterOp : public OpKernel {
|
|||
|
||||
OP_REQUIRES_OK(c, tensorflow::functor::DoCopy(c->eigen_device<Device>(),
|
||||
input, out));
|
||||
OP_REQUIRES_OK(c,
|
||||
functor::DoScatterNd<Device, T, Index, op>(
|
||||
c, indices, updates, shape, out, false /*allocate*/));
|
||||
OP_REQUIRES_OK(c, functor::DoScatterNd<Device, T, Index, op>(
|
||||
c, indices, updates, shape, out, false /*allocate*/,
|
||||
this->bad_indices_policy_));
|
||||
} else {
|
||||
// Output forwarded, so simply perform the scatter.
|
||||
OP_REQUIRES_OK(c, functor::DoScatterNd<Device, T, Index, op>(
|
||||
c, indices, updates, shape, forwarded_input.get(),
|
||||
false /*allocate*/));
|
||||
false /*allocate*/, this->bad_indices_policy_));
|
||||
|
||||
c->set_output(0, *forwarded_input);
|
||||
}
|
||||
|
|
@ -267,9 +274,10 @@ class TensorScatterOp : public OpKernel {
|
|||
|
||||
template <typename Device, typename T, typename Index,
|
||||
scatter_nd_op::UpdateOp op>
|
||||
class ScatterNdUpdateOp : public OpKernel {
|
||||
class ScatterNdUpdateOp : public ScatterOpBase<Device> {
|
||||
public:
|
||||
explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||
explicit ScatterNdUpdateOp(OpKernelConstruction* c)
|
||||
: ScatterOpBase<Device>(c) {
|
||||
const DataType dt = DataTypeToEnum<T>::v();
|
||||
const DataType dt_ref = DataTypeToEnum<T>::ref();
|
||||
const DataType index_t = DataTypeToEnum<Index>::v();
|
||||
|
|
@ -344,10 +352,9 @@ class ScatterNdUpdateOp : public OpKernel {
|
|||
params = *params_ptr;
|
||||
}
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
c, functor::DoScatterNd<Device, T, Index, op>(
|
||||
c, indices, updates, params_shape, ¶ms, false /*allocate*/));
|
||||
OP_REQUIRES_OK(c, functor::DoScatterNd<Device, T, Index, op>(
|
||||
c, indices, updates, params_shape, ¶ms,
|
||||
false /*allocate*/, this->bad_indices_policy_));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -13,15 +13,20 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
|
|
@ -33,10 +38,182 @@ limitations under the License.
|
|||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tsl/lib/core/status_test_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class TensorScatterUpdateOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void MakeOp(DataType variable_type, DataType index_type) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", "TensorScatterUpdate")
|
||||
.Input(FakeInput(variable_type))
|
||||
.Input(FakeInput(index_type))
|
||||
.Input(FakeInput(variable_type))
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TensorScatterUpdateOpTest, Simple_TwoD32) {
|
||||
MakeOp(DT_FLOAT, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 3}),
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
|
||||
AddInputFromArray<float>(TensorShape({3, 3}),
|
||||
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check output.
|
||||
Tensor params_tensor = *mutable_input(0).tensor;
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
|
||||
test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
|
||||
10002, 0, 0, 0, 777, 778, 779});
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(TensorScatterUpdateOpTest, Simple_Two64) {
|
||||
MakeOp(DT_FLOAT, DT_INT64);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 3}),
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int64_t>(TensorShape({3, 1}), {0, 4, 2});
|
||||
AddInputFromArray<float>(TensorShape({3, 3}),
|
||||
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check output
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
|
||||
test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
|
||||
10002, 0, 0, 0, 777, 778, 779});
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(TensorScatterUpdateOpTest, Simple_ZeroD) {
|
||||
MakeOp(DT_FLOAT, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 1}), {0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({1, 1}), {3});
|
||||
AddInputFromArray<float>(TensorShape({1, 1}), {101});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check output
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 1}));
|
||||
test::FillValues<float>(&expected, {0, 0, 0, 101, 0});
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(TensorScatterUpdateOpTest, Simple_OneD) {
|
||||
MakeOp(DT_FLOAT, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 1}), {0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
|
||||
AddInputFromArray<float>(TensorShape({3, 1}), {100, 101, 102});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check output
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 1}));
|
||||
test::FillValues<float>(&expected, {100, 0, 102, 0, 101});
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(TensorScatterUpdateOpTest, HigherRank) {
|
||||
MakeOp(DT_FLOAT, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({2, 3, 1}), {0, 4, 2, 1, 3, 6});
|
||||
AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check output
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({8}));
|
||||
test::FillValues<float>(&expected, {10, 40, 30, 50, 20, 0, 60, 0});
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(TensorScatterUpdateOpTest, Error_IndexOutOfRange) {
|
||||
MakeOp(DT_FLOAT, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 3}),
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 99, 4});
|
||||
AddInputFromArray<float>(TensorShape({3, 3}),
|
||||
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
|
||||
Status s = RunOpKernel();
|
||||
EXPECT_TRUE(absl::StrContains(
|
||||
s.ToString(), "indices[1] = [99] does not index into shape [5,3]"))
|
||||
<< s;
|
||||
}
|
||||
|
||||
class TensorScatterUpdateOpErrorOnBadIndicesTest : public OpsTestBase {
|
||||
protected:
|
||||
void MakeOp(DataType variable_type, DataType index_type) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", "TensorScatterUpdate")
|
||||
.Input(FakeInput(variable_type))
|
||||
.Input(FakeInput(index_type))
|
||||
.Input(FakeInput(variable_type))
|
||||
.Attr("bad_indices_policy", "ERROR")
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TensorScatterUpdateOpErrorOnBadIndicesTest, Error_IndexOutOfRange) {
|
||||
MakeOp(DT_FLOAT, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 3}),
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 99, 4});
|
||||
AddInputFromArray<float>(TensorShape({3, 3}),
|
||||
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
|
||||
Status s = RunOpKernel();
|
||||
EXPECT_TRUE(absl::StrContains(
|
||||
s.ToString(), "indices[1] = [99] does not index into shape [5,3]"))
|
||||
<< s;
|
||||
}
|
||||
|
||||
class TensorScatterUpdateOpIgnoreBadIndicesTest : public OpsTestBase {
|
||||
protected:
|
||||
void MakeOp(DataType variable_type, DataType index_type) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", "TensorScatterUpdate")
|
||||
.Input(FakeInput(variable_type))
|
||||
.Input(FakeInput(index_type))
|
||||
.Input(FakeInput(variable_type))
|
||||
.Attr("bad_indices_policy", "IGNORE")
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TensorScatterUpdateOpIgnoreBadIndicesTest, DropOutOfRangeIndices) {
|
||||
MakeOp(DT_FLOAT, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
// tensor: output tensor of 5x1 shape, initialized to 0.
|
||||
AddInputFromArray<float>(TensorShape({5, 1}), {0, 0, 0, 0, 0});
|
||||
// Put the bad index in the middle to make sure the others are still updated.
|
||||
// Index: [[0], [5], [2]].
|
||||
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 5, 2});
|
||||
// Updates: [100, 101, 102].
|
||||
AddInputFromArray<float>(TensorShape({3, 1}), {100, 101, 102});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check the output.
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 1}));
|
||||
// The valid index range is [0,5). Expect to drop index[1] of value "5" and
|
||||
// update otuput[0] and output[2].
|
||||
test::FillValues<float>(&expected, {100, 0, 102, 0, 0});
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
class ScatterNdUpdateOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void MakeOp(DataType variable_ref_type, DataType index_type) {
|
||||
|
|
@ -241,6 +418,80 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
|
|||
<< s;
|
||||
}
|
||||
|
||||
class ScatterNdUpdateOpErrorOnBadIndicesTest : public OpsTestBase {
|
||||
protected:
|
||||
void MakeOp(DataType variable_ref_type, DataType index_type) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNdUpdate")
|
||||
.Input(FakeInput(variable_ref_type))
|
||||
.Input(FakeInput(index_type))
|
||||
.Input(FakeInput(RemoveRefType(variable_ref_type)))
|
||||
.Attr("bad_indices_policy", "ERROR")
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ScatterNdUpdateOpErrorOnBadIndicesTest, Error_IndexOutOfRange) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 3}),
|
||||
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 99, 4});
|
||||
AddInputFromArray<float>(TensorShape({3, 3}),
|
||||
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
|
||||
Status s = RunOpKernel();
|
||||
EXPECT_TRUE(absl::StrContains(
|
||||
s.ToString(), "indices[1] = [99] does not index into shape [5,3]"))
|
||||
<< s;
|
||||
}
|
||||
|
||||
class ScatterNdUpdateOpIgnoreBadIndicesTest : public OpsTestBase {
|
||||
protected:
|
||||
void MakeOp(DataType variable_ref_type, DataType index_type) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNdUpdate")
|
||||
.Input(FakeInput(variable_ref_type))
|
||||
.Input(FakeInput(index_type))
|
||||
.Input(FakeInput(RemoveRefType(variable_ref_type)))
|
||||
.Attr("bad_indices_policy", "IGNORE")
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ScatterNdUpdateOpIgnoreBadIndicesTest, DropOutOfRangeIndices) {
|
||||
MakeOp(DT_FLOAT_REF, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
// Put the bad index in the middle to make sure the others are still updated.
|
||||
// ref: output tensor of 5x1 shape, initialized to 0.
|
||||
AddInputFromArray<float>(TensorShape({5, 1}), {0, 0, 0, 0, 0});
|
||||
// Index: [[0], [5], [2]].
|
||||
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 5, 2});
|
||||
// Updates: [100, 101, 102].
|
||||
AddInputFromArray<float>(TensorShape({3, 1}), {100, 101, 102});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check the output.
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 1}));
|
||||
// The valid index range is [0,5). Expect to drop index[1] of value "5" and
|
||||
// update otuput[0] and output[2].
|
||||
test::FillValues<float>(&expected, {100, 0, 102, 0, 0});
|
||||
test::ExpectTensorEqual<float>(expected, *mutable_input(0).tensor);
|
||||
}
|
||||
|
||||
class ScatterNdUpdateOpConstructionTest : public OpsTestBase {};
|
||||
|
||||
TEST_F(ScatterNdUpdateOpConstructionTest, Error_BadIndicesPolicyInvalid) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNd")
|
||||
.Input(FakeInput(DT_INT32))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_INT32))
|
||||
.Attr("bad_indices_policy", "AN_UNRECOGNIZED_POLICY")
|
||||
.Finalize(node_def()));
|
||||
EXPECT_NE(InitOp(), absl::OkStatus());
|
||||
}
|
||||
|
||||
class ScatterNdOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void MakeOp(DataType variable_type, DataType index_type) {
|
||||
|
|
|
|||
|
|
@ -3193,6 +3193,7 @@ REGISTER_OP("TensorScatterUpdate")
|
|||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int16, int32, int64, uint16}")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdTensorShape);
|
||||
|
||||
REGISTER_OP("TensorScatterAdd")
|
||||
|
|
@ -3202,6 +3203,7 @@ REGISTER_OP("TensorScatterAdd")
|
|||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdTensorShape);
|
||||
|
||||
REGISTER_OP("TensorScatterSub")
|
||||
|
|
@ -3211,6 +3213,7 @@ REGISTER_OP("TensorScatterSub")
|
|||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdTensorShape);
|
||||
|
||||
REGISTER_OP("TensorScatterMin")
|
||||
|
|
@ -3220,6 +3223,7 @@ REGISTER_OP("TensorScatterMin")
|
|||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdTensorShape);
|
||||
|
||||
REGISTER_OP("TensorScatterMax")
|
||||
|
|
@ -3229,6 +3233,7 @@ REGISTER_OP("TensorScatterMax")
|
|||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdTensorShape);
|
||||
|
||||
REGISTER_OP("ScatterNdNonAliasingAdd")
|
||||
|
|
@ -3238,6 +3243,7 @@ REGISTER_OP("ScatterNdNonAliasingAdd")
|
|||
.Output("output: T")
|
||||
.Attr("T: {numbertype, bool}")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdTensorShape);
|
||||
|
||||
REGISTER_OP("FakeQuantWithMinMaxArgs")
|
||||
|
|
|
|||
|
|
@ -227,6 +227,7 @@ REGISTER_OP("ScatterNdUpdate")
|
|||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = true")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ResourceScatterNdUpdate")
|
||||
|
|
@ -236,6 +237,7 @@ REGISTER_OP("ResourceScatterNdUpdate")
|
|||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = true")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ResourceScatterNdAdd")
|
||||
|
|
@ -245,6 +247,7 @@ REGISTER_OP("ResourceScatterNdAdd")
|
|||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = true")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ResourceScatterNdSub")
|
||||
|
|
@ -254,6 +257,7 @@ REGISTER_OP("ResourceScatterNdSub")
|
|||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = true")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ResourceScatterNdMin")
|
||||
|
|
@ -263,6 +267,7 @@ REGISTER_OP("ResourceScatterNdMin")
|
|||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = true")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ResourceScatterNdMax")
|
||||
|
|
@ -272,6 +277,7 @@ REGISTER_OP("ResourceScatterNdMax")
|
|||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = true")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ScatterNdAdd")
|
||||
|
|
@ -282,6 +288,7 @@ REGISTER_OP("ScatterNdAdd")
|
|||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ScatterNdSub")
|
||||
|
|
@ -292,6 +299,7 @@ REGISTER_OP("ScatterNdSub")
|
|||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ScatterNdMax")
|
||||
|
|
@ -302,6 +310,7 @@ REGISTER_OP("ScatterNdMax")
|
|||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("ScatterNdMin")
|
||||
|
|
@ -312,6 +321,7 @@ REGISTER_OP("ScatterNdMin")
|
|||
.Attr("T: numbertype")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.Attr("use_locking: bool = false")
|
||||
.Attr("bad_indices_policy: string = ''")
|
||||
.SetShapeFn(ScatterNdUpdateShape);
|
||||
|
||||
REGISTER_OP("CountUpTo")
|
||||
|
|
|
|||
|
|
@ -2390,23 +2390,23 @@ tf_module {
|
|||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_add"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_add"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_max"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_min"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_sub"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_update"
|
||||
|
|
@ -2414,7 +2414,7 @@ tf_module {
|
|||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_sub"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_update"
|
||||
|
|
|
|||
|
|
@ -3858,23 +3858,23 @@ tf_module {
|
|||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdAdd"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdMax"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdMin"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdSub"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdUpdate"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterSub"
|
||||
|
|
@ -4122,27 +4122,27 @@ tf_module {
|
|||
}
|
||||
member_method {
|
||||
name: "ScatterNdAdd"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdMax"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdMin"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdNonAliasingAdd"
|
||||
argspec: "args=[\'input\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdSub"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdUpdate"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterSub"
|
||||
|
|
@ -5366,23 +5366,23 @@ tf_module {
|
|||
}
|
||||
member_method {
|
||||
name: "TensorScatterAdd"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterMax"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterMin"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterSub"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterUpdate"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorSliceDataset"
|
||||
|
|
|
|||
|
|
@ -1118,19 +1118,19 @@ tf_module {
|
|||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_add"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_max"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_min"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_sub"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "tensor_scatter_nd_update"
|
||||
|
|
|
|||
|
|
@ -3858,23 +3858,23 @@ tf_module {
|
|||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdAdd"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdMax"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdMin"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdSub"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterNdUpdate"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ResourceScatterSub"
|
||||
|
|
@ -4122,27 +4122,27 @@ tf_module {
|
|||
}
|
||||
member_method {
|
||||
name: "ScatterNdAdd"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdMax"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdMin"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdNonAliasingAdd"
|
||||
argspec: "args=[\'input\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'input\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdSub"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterNdUpdate"
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
|
||||
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "ScatterSub"
|
||||
|
|
@ -5366,23 +5366,23 @@ tf_module {
|
|||
}
|
||||
member_method {
|
||||
name: "TensorScatterAdd"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterMax"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterMin"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterSub"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorScatterUpdate"
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "TensorSliceDataset"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user