ConcatV2 canonicalization pass should preserve the type of the axis input arg.
PiperOrigin-RevId: 365892592
Change-Id: I9a8528bef8cfb4571249e9aa8307a511a856178f
This commit is contained in:
parent
b10ae27ac0
commit
6269e15ade
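For illustration, here is a minimal sketch of the scenario this change fixes (a hypothetical input modeled on the updated test below; the function name is made up and not part of the commit). When the ConcatV2 axis is an i32 constant, hoisting the element-wise binary ops out of the concat previously materialized the new lhs/rhs axis constants as tensor<i64> regardless of the original axis type; with this change they keep the axis constant's element type:

// Hypothetical input: the concatenation axis is an i32 constant.
func @hoistCwiseBinaryI32Axis(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>,
     %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x2xf32> {
  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  %0 = "tf.Mul"(%arg0, %arg2) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
  %1 = "tf.Mul"(%arg1, %arg3) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
  %2 = "tf.ConcatV2"(%0, %1, %axis) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<i32>) -> tensor<?x2xf32>
  return %2 : tensor<?x2xf32>
}

// After canonicalization the hoisted axis constants now stay i32,
// matching the updated FileCheck expectations below:
//   "tf.Const"() {value = dense<1> : tensor<i32>}   // lhs axis
//   "tf.Const"() {value = dense<0> : tensor<i32>}   // rhs axis
// Before this change both were always created as tensor<i64>.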
@@ -1169,12 +1169,22 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite(
     }
   }
 
-  // New lhs and rhs concatenation axis.
-  auto axis_type = mlir::RankedTensorType::get({}, rewriter.getIntegerType(64));
-  auto lhs_axis = rewriter.create<TF::ConstOp>(
-      loc, DenseIntElementsAttr::get(axis_type, hoist_params->lhs_axis));
-  auto rhs_axis = rewriter.create<TF::ConstOp>(
-      loc, DenseIntElementsAttr::get(axis_type, hoist_params->rhs_axis));
+  // New lhs and rhs concatenation axis.
+  auto axis_type =
+      mlir::RankedTensorType::get({}, mlir::getElementTypeOrSelf(axis_attr));
+  DenseIntElementsAttr lhs_attr, rhs_attr;
+  if (axis_type.getElementType().isInteger(32)) {
+    lhs_attr = DenseIntElementsAttr::get(
+        axis_type, static_cast<int32_t>(hoist_params->lhs_axis));
+    rhs_attr = DenseIntElementsAttr::get(
+        axis_type, static_cast<int32_t>(hoist_params->rhs_axis));
+  } else {
+    assert(axis_type.getElementType().isInteger(64));
+    lhs_attr = DenseIntElementsAttr::get(axis_type, hoist_params->lhs_axis);
+    rhs_attr = DenseIntElementsAttr::get(axis_type, hoist_params->rhs_axis);
+  }
+  auto lhs_axis = rewriter.create<TF::ConstOp>(loc, lhs_attr);
+  auto rhs_axis = rewriter.create<TF::ConstOp>(loc, rhs_attr);
 
   // Concatenate binary ops operands on the new axis.
   auto lhs_concat = rewriter.create<ConcatV2Op>(
@@ -174,8 +174,8 @@ func @testConcatCwiseUnary(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2
 func @testConcatCwiseBinaryOnInnerDim(%arg0: tensor<?x1xf32>,
      %arg1: tensor<?x1xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x2xf32> {
 
-  // CHECK: %[[LHS_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>}
-  // CHECK: %[[RHS_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>}
+  // CHECK: %[[LHS_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>}
+  // CHECK: %[[RHS_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
 
   // CHECK: %[[ADD_LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg2, %arg3, %[[RHS_AXIS]])
   // CHECK: %[[MUL_LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %[[LHS_AXIS]])
@@ -199,6 +199,35 @@ func @testConcatCwiseBinaryOnInnerDim(%arg0: tensor<?x1xf32>,
   return %5 : tensor<?x2xf32>
 }
 
+// CHECK-LABEL: testConcatCwiseBinaryPreserveAxisType
+func @testConcatCwiseBinaryPreserveAxisType(%arg0: tensor<?x1xf32>,
+     %arg1: tensor<?x1xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x2xf32> {
+
+  // CHECK: %[[LHS_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor<i64>}
+  // CHECK: %[[RHS_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>}
+
+  // CHECK: %[[ADD_LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg2, %arg3, %[[RHS_AXIS]])
+  // CHECK: %[[MUL_LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %[[LHS_AXIS]])
+  // CHECK: %[[MUL_RHS_CONCAT:.*]] = "tf.ConcatV2"(%arg2, %arg3, %[[RHS_AXIS]])
+
+  // CHECK: %[[MUL:.*]] = "tf.Mul"(%[[MUL_LHS_CONCAT]], %[[MUL_RHS_CONCAT]])
+  // CHECK-SAME: (tensor<?x2xf32>, tensor<2xf32>) -> tensor<?x2xf32>
+  // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[ADD_LHS_CONCAT]], %[[MUL]])
+  // CHECK-SAME: (tensor<2xf32>, tensor<?x2xf32>) -> tensor<?x2xf32>
+  // CHECK: return %[[ADD]]
+
+  %0 = "tf.Const"() { value = dense<1> : tensor<i64> } : () -> tensor<i64>
+  // Mul of a tensor and a scalar const.
+  %1 = "tf.Mul"(%arg0, %arg2) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  %2 = "tf.Mul"(%arg1, %arg3) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
+  // Add of a scalar const and a tensor.
+  %3 = "tf.AddV2"(%arg2, %1) : (tensor<f32>, tensor<?x1xf32>) -> tensor<?x1xf32>
+  %4 = "tf.AddV2"(%arg3, %2) : (tensor<f32>, tensor<?x1xf32>) -> tensor<?x1xf32>
+  %5 = "tf.ConcatV2"(%3, %4, %0) : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<i64>) -> tensor<?x2xf32>
+
+  return %5 : tensor<?x2xf32>
+}
+
 // CHECK-LABEL: testConcatCwiseBinaryInvalidInnerDim
 func @testConcatCwiseBinaryInvalidInnerDim(%arg0: tensor<?x2xf32>,
     %arg1: tensor<?x2xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x4xf32> {