ConcatV2 canonicalization pass should preserve the type of the axis input arg.

PiperOrigin-RevId: 365892592
Change-Id: I9a8528bef8cfb4571249e9aa8307a511a856178f
This commit is contained in:
Roman Dzhabarov 2021-03-30 14:06:54 -07:00 committed by TensorFlower Gardener
parent b10ae27ac0
commit 6269e15ade
2 changed files with 47 additions and 8 deletions

View File

@ -1169,12 +1169,22 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
} }
} }
// New lhs and rhs concatenation axis. // New lhs and rhs concatenation axis
auto axis_type = mlir::RankedTensorType::get({}, rewriter.getIntegerType(64)); auto axis_type =
auto lhs_axis = rewriter.create<TF::ConstOp>( mlir::RankedTensorType::get({}, mlir::getElementTypeOrSelf(axis_attr));
loc, DenseIntElementsAttr::get(axis_type, hoist_params->lhs_axis)); DenseIntElementsAttr lhs_attr, rhs_attr;
auto rhs_axis = rewriter.create<TF::ConstOp>( if (axis_type.getElementType().isInteger(32)) {
loc, DenseIntElementsAttr::get(axis_type, hoist_params->rhs_axis)); lhs_attr = DenseIntElementsAttr::get(
axis_type, static_cast<int32_t>(hoist_params->lhs_axis));
rhs_attr = DenseIntElementsAttr::get(
axis_type, static_cast<int32_t>(hoist_params->rhs_axis));
} else {
assert(axis_type.getElementType().isInteger(64));
lhs_attr = DenseIntElementsAttr::get(axis_type, hoist_params->lhs_axis);
rhs_attr = DenseIntElementsAttr::get(axis_type, hoist_params->rhs_axis);
}
auto lhs_axis = rewriter.create<TF::ConstOp>(loc, lhs_attr);
auto rhs_axis = rewriter.create<TF::ConstOp>(loc, rhs_attr);
// Concatenate binary ops operands on the new axis. // Concatenate binary ops operands on the new axis.
auto lhs_concat = rewriter.create<ConcatV2Op>( auto lhs_concat = rewriter.create<ConcatV2Op>(

View File

@ -174,8 +174,8 @@ func @testConcatCwiseUnary(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2
func @testConcatCwiseBinaryOnInnerDim(%arg0: tensor<?x1xf32>, func @testConcatCwiseBinaryOnInnerDim(%arg0: tensor<?x1xf32>,
%arg1: tensor<?x1xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x2xf32> { %arg1: tensor<?x1xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x2xf32> {
// CHECK: %[[LHS_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} // CHECK: %[[LHS_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK: %[[RHS_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} // CHECK: %[[RHS_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK: %[[ADD_LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg2, %arg3, %[[RHS_AXIS]]) // CHECK: %[[ADD_LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg2, %arg3, %[[RHS_AXIS]])
// CHECK: %[[MUL_LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %[[LHS_AXIS]]) // CHECK: %[[MUL_LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %[[LHS_AXIS]])
@ -199,6 +199,35 @@ func @testConcatCwiseBinaryOnInnerDim(%arg0: tensor<?x1xf32>,
return %5 : tensor<?x2xf32> return %5 : tensor<?x2xf32>
} }
// CHECK-LABEL: testConcatCwiseBinaryPreserveAxisType
func @testConcatCwiseBinaryPreserveAxisType(%arg0: tensor<?x1xf32>,
%arg1: tensor<?x1xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x2xf32> {
// CHECK: %[[LHS_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>}
// CHECK: %[[RHS_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>}
// CHECK: %[[ADD_LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg2, %arg3, %[[RHS_AXIS]])
// CHECK: %[[MUL_LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %[[LHS_AXIS]])
// CHECK: %[[MUL_RHS_CONCAT:.*]] = "tf.ConcatV2"(%arg2, %arg3, %[[RHS_AXIS]])
// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[MUL_LHS_CONCAT]], %[[MUL_RHS_CONCAT]])
// CHECK-SAME: (tensor<?x2xf32>, tensor<2xf32>) -> tensor<?x2xf32>
// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[ADD_LHS_CONCAT]], %[[MUL]])
// CHECK-SAME: (tensor<2xf32>, tensor<?x2xf32>) -> tensor<?x2xf32>
// CHECK: return %[[ADD]]
%0 = "tf.Const"() { value = dense<1> : tensor<i64> } : () -> tensor<i64>
// Mul of a tensor and a scalar const.
%1 = "tf.Mul"(%arg0, %arg2) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
%2 = "tf.Mul"(%arg1, %arg3) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
// Add of a scalar const and a tensor.
%3 = "tf.AddV2"(%arg2, %1) : (tensor<f32>, tensor<?x1xf32>) -> tensor<?x1xf32>
%4 = "tf.AddV2"(%arg3, %2) : (tensor<f32>, tensor<?x1xf32>) -> tensor<?x1xf32>
%5 = "tf.ConcatV2"(%3, %4, %0) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<i64>) -> tensor<?x2xf32>
return %5 : tensor<?x2xf32>
}
// CHECK-LABEL: testConcatCwiseBinaryInvalidInnerDim // CHECK-LABEL: testConcatCwiseBinaryInvalidInnerDim
func @testConcatCwiseBinaryInvalidInnerDim(%arg0: tensor<?x2xf32>, func @testConcatCwiseBinaryInvalidInnerDim(%arg0: tensor<?x2xf32>,
%arg1: tensor<?x2xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x4xf32> { %arg1: tensor<?x2xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x4xf32> {