From abc19d2d204810116c22e6e73c95a272cbfc7fdf Mon Sep 17 00:00:00 2001 From: Will Froom Date: Fri, 17 Oct 2025 02:57:04 -0700 Subject: [PATCH] [XLA:CPU] Combine optimization & lowering pass managers by using callback pass. PiperOrigin-RevId: 820610316 --- .../xla/xla/backends/cpu/codegen/BUILD | 2 +- .../backends/cpu/codegen/fusion_compiler.cc | 67 ++++++++++--------- .../backends/cpu/codegen/fusion_compiler.h | 18 ++--- 3 files changed, 40 insertions(+), 47 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/codegen/BUILD b/third_party/xla/xla/backends/cpu/codegen/BUILD index af42391e513..787dc8f1991 100644 --- a/third_party/xla/xla/backends/cpu/codegen/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/BUILD @@ -157,11 +157,11 @@ cc_library( "//xla/codegen/xtile/ir:xtile", "//xla/mlir/tools/mlir_replay/public:compiler_trace_proto_cc", "//xla/mlir_hlo", - "//xla/service/gpu/model/experimental:symbolic_expr", "//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.cc b/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.cc index e0c2aa9c25d..766d9169dc4 100644 --- a/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.cc +++ b/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/functional/function_ref.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -65,6 +66,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/WalkResult.h" @@ -99,6 +101,19 @@ limitations under the License. namespace xla::cpu { +class ModuleCallbackPass + : public mlir::PassWrapper> { + public: + explicit ModuleCallbackPass(absl::FunctionRef callback) + : callback_(callback) {} + + void runOnOperation() override { callback_(getOperation()); } + + private: + absl::FunctionRef callback_; +}; + static absl::Status RunPassPipeline( mlir::ModuleOp module, mlir::PassManager& pm, mlir::interpreter::MlirCompilationTrace* trace, @@ -273,31 +288,28 @@ FusionCompiler::FusionCompiler(mlir::MLIRContext* context, Options options, CompilationHooks hooks) : options_(std::move(options)), hooks_(std::move(hooks)), - scalar_optimization_pass_manager_( - mlir::PassManager::on(context)), - tiled_optimization_pass_manager_( - mlir::PassManager::on(context)), - scalar_lowering_pass_manager_( - mlir::PassManager::on(context)), - tiled_lowering_pass_manager_( - mlir::PassManager::on(context)) { + scalar_pass_manager_(mlir::PassManager::on(context)), + tiled_pass_manager_(mlir::PassManager::on(context)) { // Scalar passes. - AddScalarOptimizationPasses(scalar_optimization_pass_manager_, - options_.vector_width); - AddScalarLoweringPasses(scalar_lowering_pass_manager_, options_.vector_width, + AddScalarOptimizationPasses(scalar_pass_manager_, options_.vector_width); + if (hooks_.post_optimization) { + scalar_pass_manager_.addPass( + std::make_unique(hooks_.post_optimization)); + } + AddScalarLoweringPasses(scalar_pass_manager_, options_.vector_width, options_.fast_min_max); // Tiled passes. - AddTiledOptimizationPasses(tiled_optimization_pass_manager_); - AddTiledLoweringPasses(tiled_lowering_pass_manager_); + AddTiledOptimizationPasses(tiled_pass_manager_); + if (hooks_.post_optimization) { + tiled_pass_manager_.addPass( + std::make_unique(hooks_.post_optimization)); + } + AddTiledLoweringPasses(tiled_pass_manager_); - scalar_optimization_pass_manager_.addInstrumentation( + scalar_pass_manager_.addInstrumentation( std::make_unique()); - scalar_lowering_pass_manager_.addInstrumentation( - std::make_unique()); - tiled_optimization_pass_manager_.addInstrumentation( - std::make_unique()); - tiled_lowering_pass_manager_.addInstrumentation( + tiled_pass_manager_.addInstrumentation( std::make_unique()); } @@ -317,11 +329,7 @@ absl::StatusOr> FusionCompiler::Compile( }; bool is_tiled = !mlir_module.getBody()->getOps().empty(); - mlir::PassManager& optimization_pm = is_tiled - ? tiled_optimization_pass_manager_ - : scalar_optimization_pass_manager_; - mlir::PassManager& lowering_pm = - is_tiled ? tiled_lowering_pass_manager_ : scalar_lowering_pass_manager_; + mlir::PassManager& pm = is_tiled ? tiled_pass_manager_ : scalar_pass_manager_; VLOG(1) << "Compiling MLIR module: " << module_name << ", with " << get_module_op_count() << " operations."; @@ -337,15 +345,8 @@ absl::StatusOr> FusionCompiler::Compile( if (hooks_.pre_optimization) { hooks_.pre_optimization(mlir_module); } - TF_RETURN_IF_ERROR(RunPassPipeline(mlir_module, optimization_pm, nullptr, - options_.verification_level)); - - if (hooks_.post_optimization) { - hooks_.post_optimization(mlir_module); - } - - TF_RETURN_IF_ERROR(RunPassPipeline(mlir_module, lowering_pm, nullptr, - options_.verification_level)); + TF_RETURN_IF_ERROR( + RunPassPipeline(mlir_module, pm, nullptr, options_.verification_level)); if (hooks_.post_lowering) { hooks_.post_lowering(mlir_module); diff --git a/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.h b/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.h index 5c301e1fe96..bd575a40ee5 100644 --- a/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.h +++ b/third_party/xla/xla/backends/cpu/codegen/fusion_compiler.h @@ -67,19 +67,11 @@ class FusionCompiler { private: Options options_; CompilationHooks hooks_; - // The reason we have 4 distinct pass managers is because: - // - We have 2 stages: optimization and lowering, this is to enable dumping - // of the intermediate optimized MLIR. - // - We have 2 distinct pipelines for scalar and tiled kernels, this is - // because they differ slightly in their semantics, ideally these would be - // unified but this is a larger change. - // Pass manager that holds the optimization & loop transformation passes. - mlir::PassManager scalar_optimization_pass_manager_; - mlir::PassManager tiled_optimization_pass_manager_; - // Pass manager that holds the passes responsible for lowering the module from - // MLIR to LLVM. - mlir::PassManager scalar_lowering_pass_manager_; - mlir::PassManager tiled_lowering_pass_manager_; + // We have 2 distinct pipelines for scalar and tiled kernels, this is + // because they differ slightly in their semantics, ideally these would be + // unified but this is a larger change. + mlir::PassManager scalar_pass_manager_; + mlir::PassManager tiled_pass_manager_; }; } // namespace xla::cpu