support bad_indices_policy for TensorScatterOp and ScatterNdUpdateOp

This CL introduces the bad_indices_policy as ScatterNd already had.

PiperOrigin-RevId: 654727734
This commit is contained in:
A. Unique TensorFlower 2024-07-22 06:53:04 -07:00 committed by TensorFlower Gardener
parent f24daf98fc
commit 14725f749c
21 changed files with 460 additions and 124 deletions

View File

@ -37,6 +37,9 @@
* Add support for `stablehlo.composite`.
* `EmbeddingLookup` op supports `TensorType_INT4` values.
* `tf.tensor_scatter_update`, `tf.tensor_scatter_add` and of other reduce types.
* Support `bad_indices_policy`.
## Keras
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>

View File

@ -626,7 +626,7 @@ func.func @scatter_nd_add(%arg0: tensor<7xi64>, %arg1: tensor<1x1xi32>, %arg2: t
// CHECK-LABEL: scatter_nd_add
// CHECK: %[[GATHER:.*]] = "tf.GatherNd"(%arg0, %arg1) <{bad_indices_policy = ""}> : (tensor<7xi64>, tensor<1x1xi32>) -> tensor<1xi64>
// CHECK: %[[ADD:.*]] = "tf.Add"(%arg2, %[[GATHER]]) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[SCATTER:.*]] = "tf.TensorScatterUpdate"(%arg0, %arg1, %[[ADD]]) : (tensor<7xi64>, tensor<1x1xi32>, tensor<1xi64>) -> tensor<7xi64>
// CHECK: %[[SCATTER:.*]] = "tf.TensorScatterUpdate"(%arg0, %arg1, %[[ADD]]) <{bad_indices_policy = ""}> : (tensor<7xi64>, tensor<1x1xi32>, tensor<1xi64>) -> tensor<7xi64>
// CHECK: return %[[SCATTER]] : tensor<7xi64>
}

View File

@ -24,7 +24,7 @@ def DenseElementsAttr : ElementsAttrBase<
"non-opaque constant tensor">;
def CreateGatherNdOp : NativeCodeCall<
"$_builder.create<TF::GatherNdOp>($0.getLoc(), $0.getType(), $1, $2)">;
"$_builder.create<TF::GatherNdOp>($0.getLoc(), $0.getType(), $1, $2, $3)">;
def CreateTFCastOpI32 : NativeCodeCall<
"CreateTFCastOpI32(&$_builder, $_loc, $0, $1)">;
@ -197,9 +197,14 @@ def ConvertMatrixSetDiagV3ToMatrixSetDiag :
//===----------------------------------------------------------------------===//
def LowerTensorScatterAdd: Pat<
(TF_TensorScatterAddOp $input, $indices, $updates),
(TF_TensorScatterUpdateOp $input, $indices,
(TF_AddOp $updates, (CreateGatherNdOp $updates, $input, $indices)))>;
(TF_TensorScatterAddOp $input, $indices, $updates, $bad_indices_policy),
(TF_TensorScatterUpdateOp
$input,
$indices,
(TF_AddOp
$updates,
(CreateGatherNdOp $updates, $input, $indices, $bad_indices_policy)),
$bad_indices_policy)>;
//===----------------------------------------------------------------------===//
// AddV2 op patterns.

View File

@ -26,7 +26,7 @@ func.func @xla_gather(%arg0: tensor<?x2xf32>, %arg1: tensor<1xi32>, %arg2: tenso
// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<1> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi64>}> : () -> tensor<1xi64>
// CHECK: %[[arg1_i64:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi64>
// CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) : (tensor<2xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<2xi64>
// CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) <{bad_indices_policy = ""}> : (tensor<2xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<2xi64>
// CHECK: %[[arg2_i64:.*]] = "tf.Cast"(%arg2) <{Truncate = false}> : (tensor<2xi32>) -> tensor<2xi64>
// CHECK: %[[slice:.*]] = "tf.Slice"(%arg0, %[[tensor_scatter_update]], %[[arg2_i64]]) : (tensor<?x2xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32>
// CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[slice]], %[[cst_1]]) : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32>
@ -51,7 +51,7 @@ func.func @xla_gather_known_output_shape(%arg0: tensor<5xi32>, %arg1: tensor<1xi
// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64>
// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi64>}> : () -> tensor<0xi64>
// CHECK: %[[arg1_i64:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi64>
// CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) : (tensor<1xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) <{bad_indices_policy = ""}> : (tensor<1xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<1xi64>
// CHECK: %[[arg2_i64:.*]] = "tf.Cast"(%arg2) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi64>
// CHECK: %[[slice:.*]] = "tf.Slice"(%arg0, %[[tensor_scatter_update]], %[[arg2_i64]]) : (tensor<5xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi32>
// CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[slice]], %[[cst_1]]) : (tensor<1xi32>, tensor<0xi64>) -> tensor<i32>

View File

@ -13837,7 +13837,8 @@ A tensor of indices into ref.}]>:$indices,
Arg<TF_Tensor, [{A Tensor. Must have the same type as ref. A tensor of
values to add to ref.}]>:$updates,
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_locking
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_locking,
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
);
let results = (outs);
@ -13894,7 +13895,8 @@ A tensor of indices into ref.}]>:$indices,
Arg<TF_Tensor, [{A Tensor. Must have the same type as ref. A tensor of
values to add to ref.}]>:$updates,
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_locking
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_locking,
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
);
let results = (outs);
@ -13953,7 +13955,8 @@ A tensor of indices into ref.}]>:$indices,
Arg<TF_Tensor, [{A Tensor. Must have the same type as ref. A tensor of updated
values to add to ref.}]>:$updates,
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_locking
DefaultValuedOptionalAttr<BoolAttr, "true">:$use_locking,
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
);
let results = (outs);
@ -19719,14 +19722,23 @@ numpy=array([[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]],
[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]],
[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]], dtype=int32)>
Note: on CPU, if an out of bound index is found, an error is returned.
On GPU, if an out of bound index is found, the index is ignored.
If `indices` contains any out-of-bound indices, depending on
`bad_indices_policy`, the op will either return an error or ignore the
out-of-bound indices. `bad_indices_policy` can be one of the following values:
1. "" or "DEFAULT": raises on CPU and ignore on GPU. This is because
historically on CPU and GPU we handle errors in different ways, and for
backward compatibility we keep the default behavior.
2. "ERROR": raises error; GPU does not support this value.
3. "IGNORE": ignore the bad indices; supported on both CPU and GPU.
}];
let arguments = (ins
Arg<TF_Tensor, [{Tensor to copy/update.}]>:$tensor,
Arg<TF_I32OrI64Tensor, [{Index tensor.}]>:$indices,
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates,
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
);
let results = (outs
@ -19763,7 +19775,9 @@ Refer to `tf.tensor_scatter_nd_update` for more details.
let arguments = (ins
Arg<TF_Tensor, [{Tensor to update.}]>:$tensor,
Arg<TF_I32OrI64Tensor, [{Index tensor.}]>:$indices,
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates,
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
);
let results = (outs
@ -19780,7 +19794,9 @@ def TF_TensorScatterMinOp : TF_Op<"TensorScatterMin", [Pure]> {
let arguments = (ins
Arg<TF_Tensor, [{Tensor to update.}]>:$tensor,
Arg<TF_I32OrI64Tensor, [{Index tensor.}]>:$indices,
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates,
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
);
let results = (outs
@ -19864,7 +19880,9 @@ On GPU, if an out of bound index is found, the index is ignored.
let arguments = (ins
Arg<TF_Tensor, [{Tensor to copy/update.}]>:$tensor,
Arg<TF_I32OrI64Tensor, [{Index tensor.}]>:$indices,
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates,
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
);
let results = (outs
@ -19889,8 +19907,6 @@ for the existing tensor cannot be re-used, a copy is made and updated.
If `indices` contains duplicates, then we pick the last update for the index.
If an out of bound index is found on CPU, an error is returned.
**WARNING**: There are some GPU specific semantics for this operation.
- If an out of bound index is found, the index is ignored.
- The order in which updates are applied is nondeterministic, so the output
@ -19914,6 +19930,15 @@ The overall shape of `updates` is:
indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
```
If `indices` contains any out-of-bound indices, depending on
`bad_indices_policy`, the op will either return an error or ignore the
out-of-bound indices. `bad_indices_policy` can be one of the following values:
1. "" or "DEFAULT": raises on CPU and ignore on GPU. This is because
historically on CPU and GPU we handle errors in different ways, and for
backward compatibility we keep the default behavior.
2. "ERROR": raises error; GPU does not support this value.
3. "IGNORE": ignore the bad indices; supported on both CPU and GPU.
For usage examples see the python [tf.tensor_scatter_nd_update](
https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function
}];
@ -19921,7 +19946,9 @@ https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function
let arguments = (ins
Arg<TF_Tensor, [{Tensor to copy/update.}]>:$tensor,
Arg<TensorOf<[TF_Int16, TF_Int32, TF_Int64, TF_Uint16]>, [{Index tensor.}]>:$indices,
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates
Arg<TF_Tensor, [{Updates to scatter into output.}]>:$updates,
DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_policy
);
let results = (outs

View File

@ -616,7 +616,7 @@ func.func @decompose_resource_scatter_add_op(%indices : tensor<2x?xi32>, %update
// CHECK-DAG: [[READ:%.+]] = "tf.ReadVariableOp"([[VAR]])
// CHECK: [[EXPAND:%.+]] = "tf.ExpandDims"([[INDEX]], [[CST]]) : (tensor<2x?xi32>, tensor<i32>) -> tensor<2x?x1xi32>
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterAdd"([[READ]], [[EXPAND]], [[UPDATE]]) : (tensor<*xi32>, tensor<2x?x1xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterAdd"([[READ]], [[EXPAND]], [[UPDATE]]) <{bad_indices_policy = ""}> : (tensor<*xi32>, tensor<2x?x1xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
// CHECK: "tf.AssignVariableOp"([[VAR]], [[TENSOR]])
"tf.ResourceScatterAdd"(%resource, %indices, %updates) : (tensor<*x!tf_type.resource<tensor<*xi32>>>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> ()
@ -637,7 +637,7 @@ func.func @decompose_resource_scatter_add_op_1d_indices(%indices : tensor<?xi32>
// CHECK-DAG: [[READ:%.+]] = "tf.ReadVariableOp"([[VAR]])
// CHECK: [[EXPAND:%.+]] = "tf.ExpandDims"([[INDEX]], [[CST]]) : (tensor<?xi32>, tensor<i32>) -> tensor<?x1xi32>
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterAdd"([[READ]], [[EXPAND]], [[UPDATE]]) : (tensor<*xi32>, tensor<?x1xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterAdd"([[READ]], [[EXPAND]], [[UPDATE]]) <{bad_indices_policy = ""}> : (tensor<*xi32>, tensor<?x1xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
// CHECK: "tf.AssignVariableOp"([[VAR]], [[TENSOR]])
"tf.ResourceScatterAdd"(%resource, %indices, %updates) : (tensor<*x!tf_type.resource<tensor<*xi32>>>, tensor<?xi32>, tensor<?x?x?xi32>) -> ()
@ -686,7 +686,7 @@ func.func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %upd
// CHECK-DAG: [[READ:%.+]] = "tf.ReadVariableOp"([[VAR]])
// CHECK: [[EXPAND:%.+]] = "tf.ExpandDims"([[INDEX]], [[CST]]) : (tensor<2x?xi32>, tensor<i32>) -> tensor<2x?x1xi32>
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterUpdate"([[READ]], [[EXPAND]], [[UPDATE]]) : (tensor<*xi32>, tensor<2x?x1xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
// CHECK: [[TENSOR:%.+]] = "tf.TensorScatterUpdate"([[READ]], [[EXPAND]], [[UPDATE]]) <{bad_indices_policy = ""}> : (tensor<*xi32>, tensor<2x?x1xi32>, tensor<?x?x?xi32>) -> tensor<*xi32>
// CHECK: "tf.AssignVariableOp"([[VAR]], [[TENSOR]])
"tf.ResourceScatterUpdate"(%resource, %indices, %updates) : (tensor<*x!tf_type.resource<tensor<*xi32>>>, tensor<2x?xi32>, tensor<?x?x?xi32>) -> ()

View File

@ -9,10 +9,10 @@ func.func @invert_permutation(%arg0: tensor<5xi32>) -> tensor<5xi32> {
// CHECK-DAG: %[[cst_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<5xi32>}> : () -> tensor<5xi32>
// CHECK-DAG: %[[INDICES:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32>
// CHECK-DAG: %[[INDICES_1:.*]] = "tf.TensorScatterAdd"(%[[cst_3]], %[[INDICES]], %[[cst_2]]) : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
// CHECK-DAG: %[[INDICES_1:.*]] = "tf.TensorScatterAdd"(%[[cst_3]], %[[INDICES]], %[[cst_2]]) <{bad_indices_policy = ""}> : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
// CHECK-DAG: %[[INDICES_2:.*]] = "tf.Sub"(%[[cst_1]], %[[INDICES_1]]) : (tensor<i32>, tensor<5xi32>) -> tensor<5xi32>
// CHECK-DAG: %[[INDICES_3:.*]] = "tf.Mul"(%[[INDICES_2]], %arg0) : (tensor<5xi32>, tensor<5xi32>) -> tensor<5xi32>
// CHECK-DAG: %[[INDICES_4:.*]] = "tf.TensorScatterAdd"(%[[cst_3]], %0, %[[UPDATES]]) : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
// CHECK-DAG: %[[INDICES_4:.*]] = "tf.TensorScatterAdd"(%[[cst_3]], %0, %[[UPDATES]]) <{bad_indices_policy = ""}> : (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
// CHECK-DAG: %[[INDICES_5:.*]] = "tf.AddV2"(%[[INDICES_3]], %[[INDICES_4]]) : (tensor<5xi32>, tensor<5xi32>) -> tensor<5xi32>
%0 = "tf.InvertPermutation"(%arg0) : (tensor<5xi32>) -> tensor<5xi32>
func.return %0 : tensor<5xi32>
@ -825,7 +825,7 @@ func.func @Inv_i32(%arg0: tensor<*xi32>) -> tensor<*xi32> {
// CHECK-LABEL: @ScatterNd
func.func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
// CHECK: %[[ZERO:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<8xf32>}> : () -> tensor<8xf32>
// CHECK: "tf.TensorScatterAdd"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
// CHECK: "tf.TensorScatterAdd"(%[[ZERO]], %arg0, %arg1) <{bad_indices_policy = ""}> : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
%shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32>
%0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32>
@ -1216,10 +1216,10 @@ func.func @scatter_nd_updates(%arg0: tensor<14xf32>, %arg1: tensor<1x1xi32>, %ar
// CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK-DAG: %[[CST0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK-DAG: %[[CST1:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<14xf32>}> : () -> tensor<14xf32>
// CHECK: %[[SCATTER:.*]] = "tf.TensorScatterAdd"(%cst_1, %arg1, %[[CST0]]) : (tensor<14xf32>, tensor<1x1xi32>, tensor<1xf32>) -> tensor<14xf32>
// CHECK: %[[SCATTER:.*]] = "tf.TensorScatterAdd"(%cst_1, %arg1, %[[CST0]]) <{bad_indices_policy = ""}> : (tensor<14xf32>, tensor<1x1xi32>, tensor<1xf32>) -> tensor<14xf32>
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[CST]], %[[SCATTER]]) : (tensor<f32>, tensor<14xf32>) -> tensor<14xf32>
// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[SUB]], %arg0) : (tensor<14xf32>, tensor<14xf32>) -> tensor<14xf32>
// CHECK: %[[SCATTER1:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %arg2) : (tensor<14xf32>, tensor<1x1xi32>, tensor<1xf32>) -> tensor<14xf32>
// CHECK: %[[SCATTER1:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %arg2) <{bad_indices_policy = ""}> : (tensor<14xf32>, tensor<1x1xi32>, tensor<1xf32>) -> tensor<14xf32>
// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[MUL]], %[[SCATTER1]]) : (tensor<14xf32>, tensor<14xf32>) -> tensor<14xf32>
// CHECK: return %[[ADD]] : tensor<14xf32>
}
@ -1234,10 +1234,10 @@ func.func @scatter_nd_updates_bool(%arg0: tensor<1x24xi1>, %arg1: tensor<1x2x2xi
// CHECK-DAG: %[[CST1:.*]] = "tf.Const"() <{value = dense<0> : tensor<1x24xi32>}> : () -> tensor<1x24xi32>
// CHECK: %[[CAST0:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1x24xi1>) -> tensor<1x24xi32>
// CHECK: %[[CAST1:.*]] = "tf.Cast"(%arg2) <{Truncate = false}> : (tensor<1x2xi1>) -> tensor<1x2xi32>
// CHECK: %[[SCATTER:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %[[CST0]]) : (tensor<1x24xi32>, tensor<1x2x2xi32>, tensor<1x2xi32>) -> tensor<1x24xi32>
// CHECK: %[[SCATTER:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %[[CST0]]) <{bad_indices_policy = ""}> : (tensor<1x24xi32>, tensor<1x2x2xi32>, tensor<1x2xi32>) -> tensor<1x24xi32>
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[CST]], %[[SCATTER]]) : (tensor<i32>, tensor<1x24xi32>) -> tensor<1x24xi32>
// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[SUB]], %[[CAST0]]) : (tensor<1x24xi32>, tensor<1x24xi32>) -> tensor<1x24xi32>
// CHECK: %[[SCATTER1:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %[[CAST1]]) : (tensor<1x24xi32>, tensor<1x2x2xi32>, tensor<1x2xi32>) -> tensor<1x24xi32>
// CHECK: %[[SCATTER1:.*]] = "tf.TensorScatterAdd"(%[[CST1]], %arg1, %[[CAST1]]) <{bad_indices_policy = ""}> : (tensor<1x24xi32>, tensor<1x2x2xi32>, tensor<1x2xi32>) -> tensor<1x24xi32>
// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[MUL]], %[[SCATTER1]]) : (tensor<1x24xi32>, tensor<1x24xi32>) -> tensor<1x24xi32>
// CHECK: %[[CAST2:.*]] = "tf.Cast"(%[[ADD]]) <{Truncate = false}> : (tensor<1x24xi32>) -> tensor<1x24xi1>
// CHECK: return %[[CAST2]] : tensor<1x24xi1>

View File

@ -151,7 +151,7 @@ func.func @main(%arg0: tensor<10x8x9xf32>, %arg1: tensor<5xi32>, %arg2: tensor<5
%tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor<!tf_type.variant<tensor<8x9xf32>>>
// CHECK: %[[IND_SHAPE:.*]] = "tf.Const"() <{value = dense<[5, 1]> : tensor<2xi32>}> : () -> tensor<2xi32>
// CHECK: %[[IND_RESHPE:.*]] = "tf.Reshape"(%[[ARG1]], %[[IND_SHAPE]]) : (tensor<5xi32>, tensor<2xi32>) -> tensor<5x1xi32>
// CHECK: %[[SC:.*]] = "tf.TensorScatterUpdate"(%[[BUFFER]], %[[IND_RESHPE]], %[[ARG2]]) : (tensor<10x8x9xf32>, tensor<5x1xi32>, tensor<5x8x9xf32>) -> tensor<10x8x9xf32>
// CHECK: %[[SC:.*]] = "tf.TensorScatterUpdate"(%[[BUFFER]], %[[IND_RESHPE]], %[[ARG2]]) <{bad_indices_policy = ""}> : (tensor<10x8x9xf32>, tensor<5x1xi32>, tensor<5x8x9xf32>) -> tensor<10x8x9xf32>
%scatter = "tf.TensorListScatterIntoExistingList"(%tl, %arg2, %arg1) : (tensor<!tf_type.variant<tensor<8x9xf32>>>, tensor<5x8x9xf32>, tensor<5xi32>) -> tensor<!tf_type.variant<tensor<8x9xf32>>>
%stack = "tf.TensorListStack"(%scatter, %elem_shape) : (tensor<!tf_type.variant<tensor<8x9xf32>>>, tensor<2xi32>) -> tensor<10x8x9xf32>
// CHECK: return %[[SC]] : tensor<10x8x9xf32>

View File

@ -38,6 +38,14 @@ def CheckHasResourceSubtype : Constraint<CPred<"HasResourceSubtype($0)">>;
def CreateConstBoolAttrFalse : NativeCodeCall<"$_builder.getBoolAttr(false)">;
def CreateTensorScatterAddOp : NativeCodeCall<
"$_builder.create<TF::TensorScatterAddOp>("
"$0.getLoc(), $0.getType(), $0, $1, $2, $_builder.getStringAttr(\"\"))">;
def CreateTensorScatterUpdateOp : NativeCodeCall<
"$_builder.create<TF::TensorScatterUpdateOp>("
"$0.getLoc(), $0.getType(), $0, $1, $2, $_builder.getStringAttr(\"\"))">;
def CreateTFReadVariableOpFromResourceHandle : NativeCodeCall<
"$_builder.create<TF::ReadVariableOp>("
"$0.getLoc(), GetResourceSubtype($1), $1)">;
@ -354,7 +362,7 @@ def DecomposeResourceScatterAdd : Pat<
(TF_ResourceScatterAddOp:$src_op $resource, $indices, $updates),
(TF_AssignVariableOp
$resource,
(TF_TensorScatterAddOp
(CreateTensorScatterAddOp
(CreateTFReadVariableOp $src_op, $updates, $resource),
(TF_ExpandDimsOp $indices,
(TF_ConstOp (GetScalarOfType<-1> $indices))),
@ -369,7 +377,7 @@ def DecomposeResourceScatterUpdate : Pat<
(TF_ResourceScatterUpdateOp:$src_op $resource, $indices, $updates),
(TF_AssignVariableOp
$resource,
(TF_TensorScatterUpdateOp
(CreateTensorScatterUpdateOp
(CreateTFReadVariableOp $src_op, $updates, $resource),
(TF_ExpandDimsOp $indices,
(TF_ConstOp (GetScalarOfType<-1> $indices))),

View File

@ -49,7 +49,12 @@ def CreateTFCastOpI32 : NativeCodeCall<
"CreateTFCastOpI32(&$_builder, $0.getLoc(), $1, $2)">;
def CreateTensorScatterNdOp : NativeCodeCall<
"$_builder.create<TF::ScatterNdOp>($0.getLoc(), $0.getType(), $1, $2, $3)">;
"$_builder.create<TF::ScatterNdOp>("
"$0.getLoc(), $0.getType(), $1, $2, $3, $4)">;
def CreateTensorScatterUpdateOp : NativeCodeCall<
"$_builder.create<TF::TensorScatterUpdateOp>("
"$0.getLoc(), $0.getType(), $0, $1, $2, $3)">;
//===----------------------------------------------------------------------===//
// Add op patterns.
@ -462,25 +467,13 @@ def LowerOnesLikeOp : LowerInitializationOp<TF_OnesLikeOp, 1>;
def LowerScatterNdOp :
Pat<(TF_ScatterNdOp $indices,
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$updates, $shape, ConstantStrAttr<StrAttr, "">),
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:
$updates,
$shape,
$bad_indices_policy),
(TF_TensorScatterAddOp
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
$indices, $updates)>;
def LowerScatterNdOpDefaultBadIndicesPolicy :
Pat<(TF_ScatterNdOp $indices,
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$updates, $shape, ConstantStrAttr<StrAttr, "DEFAULT">),
(TF_TensorScatterAddOp
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
$indices, $updates)>;
def LowerScatterNdOpIgnoreBadIndices :
Pat<(TF_ScatterNdOp $indices,
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$updates, $shape, ConstantStrAttr<StrAttr, "IGNORE">),
(TF_TensorScatterAddOp
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
$indices, $updates)>;
$indices, $updates, $bad_indices_policy)>;
//===----------------------------------------------------------------------===//
// Xdivy, Xlog1p and Xlogy op patterns.
@ -526,9 +519,9 @@ def LowerIsFiniteOp : Pat<(TF_IsFiniteOp $x),
// TensorScatterUpdate op patterns.
//===----------------------------------------------------------------------===//
def LowerTensorScatterUpdate_1 : Pat<(TF_TensorScatterUpdateOp $input, $indices, $updates),
def LowerTensorScatterUpdate_1 : Pat<(TF_TensorScatterUpdateOp$input, $indices, $updates, $bad_indices_policy),
(TF_CastOp
(TF_TensorScatterUpdateOp
(CreateTensorScatterUpdateOp
(CreateTFCastOpI32
$input,
$input,
@ -537,20 +530,29 @@ def LowerTensorScatterUpdate_1 : Pat<(TF_TensorScatterUpdateOp $input, $indices,
(CreateTFCastOpI32
$updates,
$updates,
/*truncate=*/ConstBoolAttrFalse)),
/*truncate=*/ConstBoolAttrFalse),
$bad_indices_policy),
/*truncate=*/ConstBoolAttrFalse),
[(TensorOf<[TF_Bool]> $input), (TensorOf<[TF_Bool]> $updates)] >;
def LowerTensorScatterUpdate_2 : Pat<(TF_TensorScatterUpdateOp $input, $indices, $updates),
def LowerTensorScatterUpdate_2 : Pat<(TF_TensorScatterUpdateOp $input, $indices, $updates, $bad_indices_policy),
(TF_AddV2Op
(TF_MulOp
(TF_SubOp (TF_ConstOp (GetScalarOfType<1> $updates)),
(CreateTensorScatterNdOp $input, $indices,
(TF_OnesLikeOp $updates),
(CreateTFShapeOp $input, $input, ConstBoolAttrTrue))),
(CreateTFShapeOp
$input,
$input,
ConstBoolAttrTrue),
$bad_indices_policy)),
$input),
(CreateTensorScatterNdOp $input, $indices, $updates,
(CreateTFShapeOp $input, $input, ConstBoolAttrTrue))),
(CreateTFShapeOp
$input,
$input,
ConstBoolAttrTrue),
$bad_indices_policy)),
[(TensorOf<[TF_Int, TF_Float, TF_Complex]> $updates)]>;
//===----------------------------------------------------------------------===//

View File

@ -83,7 +83,15 @@ numpy=array([[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]],
[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]],
[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]], dtype=int32)>
Note: on CPU, if an out of bound index is found, an error is returned.
On GPU, if an out of bound index is found, the index is ignored.
If `indices` contains any out-of-bound indices, depending on
`bad_indices_policy`, the op will either return an error or ignore the
out-of-bound indices. `bad_indices_policy` can be one of the following values:
1. "" or "DEFAULT": raises on CPU and ignore on GPU. This is because
historically on CPU and GPU we handle errors in different ways, and for
backward compatibility we keep the default behavior.
2. "ERROR": raises error; GPU does not support this value.
3. "IGNORE": ignore the bad indices; supported on both CPU and GPU.
END
}

View File

@ -35,8 +35,6 @@ for the existing tensor cannot be re-used, a copy is made and updated.
If `indices` contains duplicates, then we pick the last update for the index.
If an out of bound index is found on CPU, an error is returned.
**WARNING**: There are some GPU specific semantics for this operation.
- If an out of bound index is found, the index is ignored.
- The order in which updates are applied is nondeterministic, so the output
@ -60,6 +58,15 @@ The overall shape of `updates` is:
indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
```
If `indices` contains any out-of-bound indices, depending on
`bad_indices_policy`, the op will either return an error or ignore the
out-of-bound indices. `bad_indices_policy` can be one of the following values:
1. "" or "DEFAULT": raises on CPU and ignore on GPU. This is because
historically on CPU and GPU we handle errors in different ways, and for
backward compatibility we keep the default behavior.
2. "ERROR": raises error; GPU does not support this value.
3. "IGNORE": ignore the bad indices; supported on both CPU and GPU.
For usage examples see the python [tf.tensor_scatter_nd_update](
https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function

View File

@ -5569,6 +5569,8 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)

View File

@ -79,13 +79,10 @@ bool ValidEmptyOutputShape(int64_t num_inputs, int64_t num_indices,
return (num_inputs != 0 && num_indices != 0 && num_updates != 0);
}
template <typename Device, typename T, typename Index>
class ScatterNdOp : public OpKernel {
template <typename Device>
class ScatterOpBase : public OpKernel {
public:
explicit ScatterNdOp(OpKernelConstruction* c) : OpKernel(c) {
const DataType dt = DataTypeToEnum<T>::v();
const DataType index_t = DataTypeToEnum<Index>::v();
OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
explicit ScatterOpBase(OpKernelConstruction* c) : OpKernel(c) {
std::string bad_indices_policy_str;
OP_REQUIRES_OK(c,
c->GetAttr(kBadIndicesPolicyAtrr, &bad_indices_policy_str));
@ -101,6 +98,19 @@ class ScatterNdOp : public OpKernel {
}
}
protected:
BadIndicesPolicy bad_indices_policy_ = BadIndicesPolicy::kDefault;
};
template <typename Device, typename T, typename Index>
class ScatterNdOp : public ScatterOpBase<Device> {
public:
explicit ScatterNdOp(OpKernelConstruction* c) : ScatterOpBase<Device>(c) {
const DataType dt = DataTypeToEnum<T>::v();
const DataType index_t = DataTypeToEnum<Index>::v();
OP_REQUIRES_OK(c, c->MatchSignature({index_t, dt, index_t}, {dt}));
}
void Compute(OpKernelContext* c) override {
const Tensor& indices = c->input(0);
const Tensor& updates = c->input(1);
@ -163,19 +173,16 @@ class ScatterNdOp : public OpKernel {
OP_REQUIRES_OK(
c, functor::DoScatterNd<Device, T, Index, scatter_nd_op::UpdateOp::ADD>(
c, indices, updates, shape, &out, true /*allocate*/,
bad_indices_policy_));
this->bad_indices_policy_));
c->set_output(0, out);
}
private:
BadIndicesPolicy bad_indices_policy_ = BadIndicesPolicy::kDefault;
};
template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp op>
class TensorScatterOp : public OpKernel {
class TensorScatterOp : public ScatterOpBase<Device> {
public:
explicit TensorScatterOp(OpKernelConstruction* c) : OpKernel(c) {
explicit TensorScatterOp(OpKernelConstruction* c) : ScatterOpBase<Device>(c) {
const DataType dt = DataTypeToEnum<T>::v();
const DataType index_t = DataTypeToEnum<Index>::v();
OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t, dt}, {dt}));
@ -251,14 +258,14 @@ class TensorScatterOp : public OpKernel {
OP_REQUIRES_OK(c, tensorflow::functor::DoCopy(c->eigen_device<Device>(),
input, out));
OP_REQUIRES_OK(c,
functor::DoScatterNd<Device, T, Index, op>(
c, indices, updates, shape, out, false /*allocate*/));
OP_REQUIRES_OK(c, functor::DoScatterNd<Device, T, Index, op>(
c, indices, updates, shape, out, false /*allocate*/,
this->bad_indices_policy_));
} else {
// Output forwarded, so simply perform the scatter.
OP_REQUIRES_OK(c, functor::DoScatterNd<Device, T, Index, op>(
c, indices, updates, shape, forwarded_input.get(),
false /*allocate*/));
false /*allocate*/, this->bad_indices_policy_));
c->set_output(0, *forwarded_input);
}
@ -267,9 +274,10 @@ class TensorScatterOp : public OpKernel {
template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp op>
class ScatterNdUpdateOp : public OpKernel {
class ScatterNdUpdateOp : public ScatterOpBase<Device> {
public:
explicit ScatterNdUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
explicit ScatterNdUpdateOp(OpKernelConstruction* c)
: ScatterOpBase<Device>(c) {
const DataType dt = DataTypeToEnum<T>::v();
const DataType dt_ref = DataTypeToEnum<T>::ref();
const DataType index_t = DataTypeToEnum<Index>::v();
@ -344,10 +352,9 @@ class ScatterNdUpdateOp : public OpKernel {
params = *params_ptr;
}
}
OP_REQUIRES_OK(
c, functor::DoScatterNd<Device, T, Index, op>(
c, indices, updates, params_shape, &params, false /*allocate*/));
OP_REQUIRES_OK(c, functor::DoScatterNd<Device, T, Index, op>(
c, indices, updates, params_shape, &params,
false /*allocate*/, this->bad_indices_policy_));
}
};

View File

@ -13,15 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
@ -33,10 +38,182 @@ limitations under the License.
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tsl/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
class TensorScatterUpdateOpTest : public OpsTestBase {
protected:
void MakeOp(DataType variable_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "TensorScatterUpdate")
.Input(FakeInput(variable_type))
.Input(FakeInput(index_type))
.Input(FakeInput(variable_type))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
};
TEST_F(TensorScatterUpdateOpTest, Simple_TwoD32) {
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
AddInputFromArray<float>(TensorShape({3, 3}),
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
TF_ASSERT_OK(RunOpKernel());
// Check output.
Tensor params_tensor = *mutable_input(0).tensor;
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
10002, 0, 0, 0, 777, 778, 779});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(TensorScatterUpdateOpTest, Simple_Two64) {
MakeOp(DT_FLOAT, DT_INT64);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int64_t>(TensorShape({3, 1}), {0, 4, 2});
AddInputFromArray<float>(TensorShape({3, 3}),
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
TF_ASSERT_OK(RunOpKernel());
// Check output
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
10002, 0, 0, 0, 777, 778, 779});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(TensorScatterUpdateOpTest, Simple_ZeroD) {
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 1}), {0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({1, 1}), {3});
AddInputFromArray<float>(TensorShape({1, 1}), {101});
TF_ASSERT_OK(RunOpKernel());
// Check output
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 1}));
test::FillValues<float>(&expected, {0, 0, 0, 101, 0});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(TensorScatterUpdateOpTest, Simple_OneD) {
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 1}), {0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
AddInputFromArray<float>(TensorShape({3, 1}), {100, 101, 102});
TF_ASSERT_OK(RunOpKernel());
// Check output
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 1}));
test::FillValues<float>(&expected, {100, 0, 102, 0, 101});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(TensorScatterUpdateOpTest, HigherRank) {
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({2, 3, 1}), {0, 4, 2, 1, 3, 6});
AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60});
TF_ASSERT_OK(RunOpKernel());
// Check output
Tensor expected(allocator(), DT_FLOAT, TensorShape({8}));
test::FillValues<float>(&expected, {10, 40, 30, 50, 20, 0, 60, 0});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(TensorScatterUpdateOpTest, Error_IndexOutOfRange) {
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 99, 4});
AddInputFromArray<float>(TensorShape({3, 3}),
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
Status s = RunOpKernel();
EXPECT_TRUE(absl::StrContains(
s.ToString(), "indices[1] = [99] does not index into shape [5,3]"))
<< s;
}
class TensorScatterUpdateOpErrorOnBadIndicesTest : public OpsTestBase {
protected:
void MakeOp(DataType variable_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "TensorScatterUpdate")
.Input(FakeInput(variable_type))
.Input(FakeInput(index_type))
.Input(FakeInput(variable_type))
.Attr("bad_indices_policy", "ERROR")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
};
TEST_F(TensorScatterUpdateOpErrorOnBadIndicesTest, Error_IndexOutOfRange) {
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 99, 4});
AddInputFromArray<float>(TensorShape({3, 3}),
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
Status s = RunOpKernel();
EXPECT_TRUE(absl::StrContains(
s.ToString(), "indices[1] = [99] does not index into shape [5,3]"))
<< s;
}
class TensorScatterUpdateOpIgnoreBadIndicesTest : public OpsTestBase {
protected:
void MakeOp(DataType variable_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "TensorScatterUpdate")
.Input(FakeInput(variable_type))
.Input(FakeInput(index_type))
.Input(FakeInput(variable_type))
.Attr("bad_indices_policy", "IGNORE")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
};
TEST_F(TensorScatterUpdateOpIgnoreBadIndicesTest, DropOutOfRangeIndices) {
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
// tensor: output tensor of 5x1 shape, initialized to 0.
AddInputFromArray<float>(TensorShape({5, 1}), {0, 0, 0, 0, 0});
// Put the bad index in the middle to make sure the others are still updated.
// Index: [[0], [5], [2]].
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 5, 2});
// Updates: [100, 101, 102].
AddInputFromArray<float>(TensorShape({3, 1}), {100, 101, 102});
TF_ASSERT_OK(RunOpKernel());
// Check the output.
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 1}));
// The valid index range is [0,5). Expect to drop index[1] of value "5" and
// update otuput[0] and output[2].
test::FillValues<float>(&expected, {100, 0, 102, 0, 0});
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
class ScatterNdUpdateOpTest : public OpsTestBase {
protected:
void MakeOp(DataType variable_ref_type, DataType index_type) {
@ -241,6 +418,80 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
<< s;
}
class ScatterNdUpdateOpErrorOnBadIndicesTest : public OpsTestBase {
protected:
void MakeOp(DataType variable_ref_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNdUpdate")
.Input(FakeInput(variable_ref_type))
.Input(FakeInput(index_type))
.Input(FakeInput(RemoveRefType(variable_ref_type)))
.Attr("bad_indices_policy", "ERROR")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
};
TEST_F(ScatterNdUpdateOpErrorOnBadIndicesTest, Error_IndexOutOfRange) {
MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 99, 4});
AddInputFromArray<float>(TensorShape({3, 3}),
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
Status s = RunOpKernel();
EXPECT_TRUE(absl::StrContains(
s.ToString(), "indices[1] = [99] does not index into shape [5,3]"))
<< s;
}
class ScatterNdUpdateOpIgnoreBadIndicesTest : public OpsTestBase {
protected:
void MakeOp(DataType variable_ref_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNdUpdate")
.Input(FakeInput(variable_ref_type))
.Input(FakeInput(index_type))
.Input(FakeInput(RemoveRefType(variable_ref_type)))
.Attr("bad_indices_policy", "IGNORE")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
};
TEST_F(ScatterNdUpdateOpIgnoreBadIndicesTest, DropOutOfRangeIndices) {
MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
// Put the bad index in the middle to make sure the others are still updated.
// ref: output tensor of 5x1 shape, initialized to 0.
AddInputFromArray<float>(TensorShape({5, 1}), {0, 0, 0, 0, 0});
// Index: [[0], [5], [2]].
AddInputFromArray<int32>(TensorShape({3, 1}), {0, 5, 2});
// Updates: [100, 101, 102].
AddInputFromArray<float>(TensorShape({3, 1}), {100, 101, 102});
TF_ASSERT_OK(RunOpKernel());
// Check the output.
Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 1}));
// The valid index range is [0,5). Expect to drop index[1] of value "5" and
// update otuput[0] and output[2].
test::FillValues<float>(&expected, {100, 0, 102, 0, 0});
test::ExpectTensorEqual<float>(expected, *mutable_input(0).tensor);
}
class ScatterNdUpdateOpConstructionTest : public OpsTestBase {};
TEST_F(ScatterNdUpdateOpConstructionTest, Error_BadIndicesPolicyInvalid) {
TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNd")
.Input(FakeInput(DT_INT32))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_INT32))
.Attr("bad_indices_policy", "AN_UNRECOGNIZED_POLICY")
.Finalize(node_def()));
EXPECT_NE(InitOp(), absl::OkStatus());
}
class ScatterNdOpTest : public OpsTestBase {
protected:
void MakeOp(DataType variable_type, DataType index_type) {

View File

@ -3193,6 +3193,7 @@ REGISTER_OP("TensorScatterUpdate")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int16, int32, int64, uint16}")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdTensorShape);
REGISTER_OP("TensorScatterAdd")
@ -3202,6 +3203,7 @@ REGISTER_OP("TensorScatterAdd")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdTensorShape);
REGISTER_OP("TensorScatterSub")
@ -3211,6 +3213,7 @@ REGISTER_OP("TensorScatterSub")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdTensorShape);
REGISTER_OP("TensorScatterMin")
@ -3220,6 +3223,7 @@ REGISTER_OP("TensorScatterMin")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdTensorShape);
REGISTER_OP("TensorScatterMax")
@ -3229,6 +3233,7 @@ REGISTER_OP("TensorScatterMax")
.Output("output: T")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdTensorShape);
REGISTER_OP("ScatterNdNonAliasingAdd")
@ -3238,6 +3243,7 @@ REGISTER_OP("ScatterNdNonAliasingAdd")
.Output("output: T")
.Attr("T: {numbertype, bool}")
.Attr("Tindices: {int32, int64}")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdTensorShape);
REGISTER_OP("FakeQuantWithMinMaxArgs")

View File

@ -227,6 +227,7 @@ REGISTER_OP("ScatterNdUpdate")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = true")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdUpdateShape);
REGISTER_OP("ResourceScatterNdUpdate")
@ -236,6 +237,7 @@ REGISTER_OP("ResourceScatterNdUpdate")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = true")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdUpdateShape);
REGISTER_OP("ResourceScatterNdAdd")
@ -245,6 +247,7 @@ REGISTER_OP("ResourceScatterNdAdd")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = true")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdUpdateShape);
REGISTER_OP("ResourceScatterNdSub")
@ -254,6 +257,7 @@ REGISTER_OP("ResourceScatterNdSub")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = true")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdUpdateShape);
REGISTER_OP("ResourceScatterNdMin")
@ -263,6 +267,7 @@ REGISTER_OP("ResourceScatterNdMin")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = true")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdUpdateShape);
REGISTER_OP("ResourceScatterNdMax")
@ -272,6 +277,7 @@ REGISTER_OP("ResourceScatterNdMax")
.Attr("T: type")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = true")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdUpdateShape);
REGISTER_OP("ScatterNdAdd")
@ -282,6 +288,7 @@ REGISTER_OP("ScatterNdAdd")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdUpdateShape);
REGISTER_OP("ScatterNdSub")
@ -292,6 +299,7 @@ REGISTER_OP("ScatterNdSub")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdUpdateShape);
REGISTER_OP("ScatterNdMax")
@ -302,6 +310,7 @@ REGISTER_OP("ScatterNdMax")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdUpdateShape);
REGISTER_OP("ScatterNdMin")
@ -312,6 +321,7 @@ REGISTER_OP("ScatterNdMin")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
.Attr("bad_indices_policy: string = ''")
.SetShapeFn(ScatterNdUpdateShape);
REGISTER_OP("CountUpTo")

View File

@ -2390,23 +2390,23 @@ tf_module {
}
member_method {
name: "tensor_scatter_add"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "tensor_scatter_nd_add"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "tensor_scatter_nd_max"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "tensor_scatter_nd_min"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "tensor_scatter_nd_sub"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "tensor_scatter_nd_update"
@ -2414,7 +2414,7 @@ tf_module {
}
member_method {
name: "tensor_scatter_sub"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "tensor_scatter_update"

View File

@ -3858,23 +3858,23 @@ tf_module {
}
member_method {
name: "ResourceScatterNdAdd"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
}
member_method {
name: "ResourceScatterNdMax"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
}
member_method {
name: "ResourceScatterNdMin"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
}
member_method {
name: "ResourceScatterNdSub"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
}
member_method {
name: "ResourceScatterNdUpdate"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
}
member_method {
name: "ResourceScatterSub"
@ -4122,27 +4122,27 @@ tf_module {
}
member_method {
name: "ScatterNdAdd"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
}
member_method {
name: "ScatterNdMax"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
}
member_method {
name: "ScatterNdMin"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
}
member_method {
name: "ScatterNdNonAliasingAdd"
argspec: "args=[\'input\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "ScatterNdSub"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
}
member_method {
name: "ScatterNdUpdate"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
}
member_method {
name: "ScatterSub"
@ -5366,23 +5366,23 @@ tf_module {
}
member_method {
name: "TensorScatterAdd"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "TensorScatterMax"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "TensorScatterMin"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "TensorScatterSub"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "TensorScatterUpdate"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "TensorSliceDataset"

View File

@ -1118,19 +1118,19 @@ tf_module {
}
member_method {
name: "tensor_scatter_nd_add"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "tensor_scatter_nd_max"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "tensor_scatter_nd_min"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "tensor_scatter_nd_sub"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "tensor_scatter_nd_update"

View File

@ -3858,23 +3858,23 @@ tf_module {
}
member_method {
name: "ResourceScatterNdAdd"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
}
member_method {
name: "ResourceScatterNdMax"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
}
member_method {
name: "ResourceScatterNdMin"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
}
member_method {
name: "ResourceScatterNdSub"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
}
member_method {
name: "ResourceScatterNdUpdate"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
}
member_method {
name: "ResourceScatterSub"
@ -4122,27 +4122,27 @@ tf_module {
}
member_method {
name: "ScatterNdAdd"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
}
member_method {
name: "ScatterNdMax"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
}
member_method {
name: "ScatterNdMin"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
}
member_method {
name: "ScatterNdNonAliasingAdd"
argspec: "args=[\'input\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'input\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "ScatterNdSub"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'\', \'None\'], "
}
member_method {
name: "ScatterNdUpdate"
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'\', \'None\'], "
}
member_method {
name: "ScatterSub"
@ -5366,23 +5366,23 @@ tf_module {
}
member_method {
name: "TensorScatterAdd"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "TensorScatterMax"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "TensorScatterMin"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "TensorScatterSub"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "TensorScatterUpdate"
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'tensor\', \'indices\', \'updates\', \'bad_indices_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "TensorSliceDataset"