mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Fix IfOp handling for resource lifting.
PiperOrigin-RevId: 294717587 Change-Id: Ic949434331b3ec89553114e0123ad92e4b3d2b37
This commit is contained in:
parent
34f38875b1
commit
f4525bf9f1
|
|
@ -418,15 +418,15 @@ func @launch_with_if(%arg0: tensor<i1>) -> tensor<4xf32> {
|
|||
// CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]])
|
||||
// CHECK: %[[LAUNCH:.*]]:2 = "tf_device.launch"()
|
||||
%2 = "tf_device.launch"() ( {
|
||||
// CHECK: %[[IF:.*]] = "tf.If"(%[[ARG0]], %[[READ0]], %[[READ1]])
|
||||
%3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
|
||||
output_shapes = ["tfshape$"], is_stateless = false}
|
||||
// CHECK: %[[IF:.*]]:2 = "tf.If"(%[[ARG0]], %[[READ0]], %[[READ1]])
|
||||
%3:2 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
|
||||
output_shapes = ["tfshape$","tfshape$dim { size: 4 }"], is_stateless = false}
|
||||
: (tensor<i1>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
|
||||
-> (tensor<*x!tf.resource<tensor<4xf32>>>)
|
||||
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]], %[[IF]])
|
||||
%4 = "tf.ReadVariableOp"(%3) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
|
||||
%5 = "tf.AddV2"(%4, %4) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: tf_device.return %[[ADD]], %[[IF]]
|
||||
-> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>)
|
||||
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]]#1, %[[IF]]#0)
|
||||
%4 = "tf.ReadVariableOp"(%3#0) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
|
||||
%5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: tf_device.return %[[ADD]], %[[IF]]#1
|
||||
tf_device.return %5 : tensor<4xf32>
|
||||
// CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> (tensor<4xf32>, tensor<4xf32>)
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<4xf32>
|
||||
|
|
@ -436,21 +436,21 @@ func @launch_with_if(%arg0: tensor<i1>) -> tensor<4xf32> {
|
|||
}
|
||||
// CHECK: func @if_then(%[[TARG0:.*]]: tensor<4xf32>, %[[TARG1:.*]]: tensor<4xf32>)
|
||||
func @if_then(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.resource<tensor<4xf32>>>)
|
||||
-> (tensor<*x!tf.resource<tensor<4xf32>>>) {
|
||||
-> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) {
|
||||
// CHECK-NEXT: %[[CONST:.*]] = "tf.Const"()
|
||||
%constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
"tf.AssignVariableOp"(%arg0, %constant) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
|
||||
// CHECK-NEXT: return %[[CONST]]
|
||||
return %arg0 : tensor<*x!tf.resource<tensor<4xf32>>>
|
||||
// CHECK-NEXT: return %[[CONST]], %[[CONST]]
|
||||
return %arg0, %constant : tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>
|
||||
}
|
||||
// CHECK: func @if_else(%[[EARG0:.*]]: tensor<4xf32>, %[[EARG1:.*]]: tensor<4xf32>)
|
||||
func @if_else(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.resource<tensor<4xf32>>>)
|
||||
-> (tensor<*x!tf.resource<tensor<4xf32>>>) {
|
||||
-> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) {
|
||||
%id = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<*x!tf.resource<tensor<4xf32>>>
|
||||
%read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
|
||||
"tf.AssignVariableOp"(%arg0, %read) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
|
||||
// CHECK-NEXT: return %[[EARG1]]
|
||||
return %arg0 : tensor<*x!tf.resource<tensor<4xf32>>>
|
||||
// CHECK-NEXT: return %[[EARG1]], %[[EARG1]]
|
||||
return %arg0, %read : tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
|||
|
|
@ -803,7 +803,7 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
|
|||
// Replace uses.
|
||||
for (int64_t i = 0; i < old_to_new_output_indices.size(); ++i) {
|
||||
if (old_to_new_output_indices[i] >= 0) {
|
||||
new_if.getResult(i).replaceAllUsesWith(
|
||||
if_op.getResult(i).replaceAllUsesWith(
|
||||
new_if.getResult(old_to_new_output_indices[i]));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user