Fix IfOp handling for resource lifting.

PiperOrigin-RevId: 294717587
Change-Id: Ic949434331b3ec89553114e0123ad92e4b3d2b37
Yuanzhong Xu, 2020-02-12 11:45:06 -08:00 (committed by TensorFlower Gardener)
parent 34f38875b1
commit f4525bf9f1
2 changed files with 15 additions and 15 deletions


@@ -418,15 +418,15 @@ func @launch_with_if(%arg0: tensor<i1>) -> tensor<4xf32> {
   // CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]])
   // CHECK: %[[LAUNCH:.*]]:2 = "tf_device.launch"()
   %2 = "tf_device.launch"() ( {
-    // CHECK: %[[IF:.*]] = "tf.If"(%[[ARG0]], %[[READ0]], %[[READ1]])
-    %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
-          output_shapes = ["tfshape$"], is_stateless = false}
+    // CHECK: %[[IF:.*]]:2 = "tf.If"(%[[ARG0]], %[[READ0]], %[[READ1]])
+    %3:2 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
+          output_shapes = ["tfshape$","tfshape$dim { size: 4 }"], is_stateless = false}
           : (tensor<i1>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
-          -> (tensor<*x!tf.resource<tensor<4xf32>>>)
-    // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]], %[[IF]])
-    %4 = "tf.ReadVariableOp"(%3) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
-    %5 = "tf.AddV2"(%4, %4) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-    // CHECK-NEXT: tf_device.return %[[ADD]], %[[IF]]
+          -> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>)
+    // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[IF]]#1, %[[IF]]#0)
+    %4 = "tf.ReadVariableOp"(%3#0) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
+    %5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    // CHECK-NEXT: tf_device.return %[[ADD]], %[[IF]]#1
     tf_device.return %5 : tensor<4xf32>
   // CHECK: {device = "tpu0", launch_attr = "launch_attr"} : () -> (tensor<4xf32>, tensor<4xf32>)
   }) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<4xf32>
@@ -436,21 +436,21 @@ func @launch_with_if(%arg0: tensor<i1>) -> tensor<4xf32> {
 }
 // CHECK: func @if_then(%[[TARG0:.*]]: tensor<4xf32>, %[[TARG1:.*]]: tensor<4xf32>)
 func @if_then(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.resource<tensor<4xf32>>>)
-    -> (tensor<*x!tf.resource<tensor<4xf32>>>) {
+    -> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) {
   // CHECK-NEXT: %[[CONST:.*]] = "tf.Const"()
   %constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32>
   "tf.AssignVariableOp"(%arg0, %constant) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
-  // CHECK-NEXT: return %[[CONST]]
-  return %arg0 : tensor<*x!tf.resource<tensor<4xf32>>>
+  // CHECK-NEXT: return %[[CONST]], %[[CONST]]
+  return %arg0, %constant : tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>
 }
 // CHECK: func @if_else(%[[EARG0:.*]]: tensor<4xf32>, %[[EARG1:.*]]: tensor<4xf32>)
 func @if_else(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.resource<tensor<4xf32>>>)
-    -> (tensor<*x!tf.resource<tensor<4xf32>>>) {
+    -> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) {
   %id = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<*x!tf.resource<tensor<4xf32>>>
   %read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
   "tf.AssignVariableOp"(%arg0, %read) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
-  // CHECK-NEXT: return %[[EARG1]]
-  return %arg0 : tensor<*x!tf.resource<tensor<4xf32>>>
+  // CHECK-NEXT: return %[[EARG1]], %[[EARG1]]
+  return %arg0, %read : tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>
 }
 // -----


@@ -803,7 +803,7 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
   // Replace uses.
   for (int64_t i = 0; i < old_to_new_output_indices.size(); ++i) {
     if (old_to_new_output_indices[i] >= 0) {
-      new_if.getResult(i).replaceAllUsesWith(
+      if_op.getResult(i).replaceAllUsesWith(
           new_if.getResult(old_to_new_output_indices[i]));
     }
   }
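
The one-line change above is the actual fix: the rewritten `tf.If` (`new_if`) already carries the remapped results, so the uses that need redirecting are those of the original `if_op`, not `new_if`'s own. Before the fix, the loop operated on `new_if`'s results, so the consumers of the original `if_op` were never rewired to the new op. Below is a minimal standalone sketch of the same index-remapping idea, using plain C++ containers instead of the MLIR `Value`/`replaceAllUsesWith` API; the `Use` struct and the "old_if"/"new_if" labels are illustrative stand-ins, not code from this patch.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for an SSA use: records which op and which result index a
// consumer currently reads from.
struct Use {
  std::string op;     // "old_if" or "new_if"
  int64_t result_id;  // result index on that op
};

int main() {
  // old_to_new_output_indices[i] is the result index on the new op that now
  // carries old result i, or -1 if the old result was dropped (e.g. a
  // resource output removed by lifting).
  const std::vector<int64_t> old_to_new_output_indices = {-1, 0};

  // Two consumers still reading result #1 of the *old* op.
  std::vector<Use> uses = {{"old_if", 1}, {"old_if", 1}};

  // The fixed logic: redirect uses of the old op's result i to the new op's
  // result old_to_new_output_indices[i]. Rewiring the new op's own results
  // (the pre-fix behavior) would leave these uses pointing at the old op.
  for (int64_t i = 0;
       i < static_cast<int64_t>(old_to_new_output_indices.size()); ++i) {
    if (old_to_new_output_indices[i] < 0) continue;
    for (Use& use : uses) {
      if (use.op == "old_if" && use.result_id == i) {
        use.op = "new_if";
        use.result_id = old_to_new_output_indices[i];
      }
    }
  }

  for (const Use& use : uses) {
    std::cout << use.op << "#" << use.result_id << "\n";  // prints "new_if#0" twice
  }
  return 0;
}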