[StableHLO->HLO] Only lower MHLO constants in MHLO prepare for export pass.

PiperOrigin-RevId: 822198262
This commit is contained in:
Kevin Gleason 2025-10-21 11:14:26 -07:00 committed by TensorFlower Gardener
parent 7524326efd
commit fe624fe9ce
2 changed files with 15 additions and 3 deletions

View File

@ -130,7 +130,8 @@ void prepareExplicitCapturedConstants(Operation* op) {
// it explicit and replace uses within the block
Operation *definingOp = input.getDefiningOp();
mlir::DenseElementsAttr attr;
if (matchPattern(input, m_Constant(&attr))) {
if (mlir::isa_and_present<ConstantOp>(input.getDefiningOp()) &&
matchPattern(input, m_Constant(&attr))) {
Operation *clonedOp = builder.clone(*definingOp);
// Find which uses belong to the block and replace
// with the cloned/explicit one
@ -146,9 +147,10 @@ void prepareExplicitCapturedConstants(Operation* op) {
} // namespace
void PrepareForExportPass::runOnOperation() {
getOperation().walk([&](Operation *op) {
getOperation().walk([&](Operation* op) {
mlir::SplatElementsAttr attr;
if (matchPattern(op, m_Constant(&attr))) return prepareConstantOp(op, attr);
if (isa<ConstantOp>(op) && matchPattern(op, m_Constant(&attr)))
return prepareConstantOp(op, attr);
if (auto bcastOp = dyn_cast<BroadcastInDimOp>(op))
return prepareBroadcastInDim(bcastOp);

View File

@ -11,6 +11,16 @@ func.func @splat_constants() -> tensor<1x64x224x224xf32> {
// -----
// CHECK-LABEL: @non_mhlo_constant
func.func @non_mhlo_constant() -> tensor<128x1014x508xcomplex<f64>> {
// CHECK: arith.constant dense<(1.000000e+00,2.000000e+00)> : tensor<128x1014x508xcomplex<f64>>
// CHECK-NOT: mhlo.broadcast_in_dim
%0 = arith.constant dense<(1.000000e+00,2.000000e+00)> : tensor<128x1014x508xcomplex<f64>>
func.return %0 : tensor<128x1014x508xcomplex<f64>>
}
// -----
// CHECK-LABEL: @splat_constant_complex_float
func.func @splat_constant_complex_float() -> tensor<128x1014x508xcomplex<f64>> {
// CHECK: %[[CST:.*]] = mhlo.constant dense<(1.000000e+00,2.000000e+00)> : tensor<complex<f64>>