[XLA:CPU] Combine optimization & lowering pass managers by using callback pass.

PiperOrigin-RevId: 820610316
Authored by Will Froom on 2025-10-17 02:57:04 -07:00; committed by TensorFlower Gardener
parent 5da47fcdd8
commit abc19d2d20
3 changed files with 40 additions and 47 deletions
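
The change hinges on a small MLIR pass that simply invokes a caller-supplied callback on the module, so the post-optimization hook can run in the middle of a single pipeline instead of between two separate pass managers. The sketch below illustrates the pattern using only upstream MLIR APIs; CallbackPass, BuildPipeline, and the canonicalizer/CSE stand-ins are illustrative names, not part of this commit.

// Illustrative only: a minimal callback pass in the style of the
// ModuleCallbackPass added by this commit. Names below (CallbackPass,
// BuildPipeline, the canonicalizer/CSE stand-ins) are hypothetical.
#include <memory>

#include "absl/functional/function_ref.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

namespace {

// Runs an arbitrary callback on the module; performs no IR mutation itself.
class CallbackPass
    : public mlir::PassWrapper<CallbackPass,
                               mlir::OperationPass<mlir::ModuleOp>> {
 public:
  explicit CallbackPass(absl::FunctionRef<void(mlir::ModuleOp)> callback)
      : callback_(callback) {}

  void runOnOperation() override { callback_(getOperation()); }

 private:
  absl::FunctionRef<void(mlir::ModuleOp)> callback_;
};

}  // namespace

// One pass manager covers both stages; the callback observes the module at
// the optimization/lowering boundary (e.g. to dump the optimized IR).
void BuildPipeline(mlir::PassManager& pm,
                   absl::FunctionRef<void(mlir::ModuleOp)> post_optimization) {
  pm.addPass(mlir::createCanonicalizerPass());  // stand-in for optimization passes
  pm.addPass(std::make_unique<CallbackPass>(post_optimization));
  pm.addPass(mlir::createCSEPass());            // stand-in for lowering passes
}

In the commit, the FusionCompiler constructor adds exactly such a pass (ModuleCallbackPass) wired to hooks_.post_optimization, which lets the separate optimization and lowering pass managers collapse into one per pipeline.
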

View File

@@ -157,11 +157,11 @@ cc_library(
"//xla/codegen/xtile/ir:xtile",
"//xla/mlir/tools/mlir_replay/public:compiler_trace_proto_cc",
"//xla/mlir_hlo",
"//xla/service/gpu/model/experimental:symbolic_expr",
"//xla/tsl/framework/mlir:status_scoped_diagnostic_handler",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",

View File

@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <utility>
#include "absl/functional/function_ref.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
@@ -65,6 +66,7 @@ limitations under the License.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/WalkResult.h"
@@ -99,6 +101,19 @@ limitations under the License.
namespace xla::cpu {
class ModuleCallbackPass
: public mlir::PassWrapper<ModuleCallbackPass,
mlir::OperationPass<mlir::ModuleOp>> {
public:
explicit ModuleCallbackPass(absl::FunctionRef<void(mlir::ModuleOp)> callback)
: callback_(callback) {}
void runOnOperation() override { callback_(getOperation()); }
private:
absl::FunctionRef<void(mlir::ModuleOp)> callback_;
};
static absl::Status RunPassPipeline(
mlir::ModuleOp module, mlir::PassManager& pm,
mlir::interpreter::MlirCompilationTrace* trace,
@@ -273,31 +288,28 @@ FusionCompiler::FusionCompiler(mlir::MLIRContext* context, Options options,
CompilationHooks hooks)
: options_(std::move(options)),
hooks_(std::move(hooks)),
scalar_optimization_pass_manager_(
mlir::PassManager::on<mlir::ModuleOp>(context)),
tiled_optimization_pass_manager_(
mlir::PassManager::on<mlir::ModuleOp>(context)),
scalar_lowering_pass_manager_(
mlir::PassManager::on<mlir::ModuleOp>(context)),
tiled_lowering_pass_manager_(
mlir::PassManager::on<mlir::ModuleOp>(context)) {
scalar_pass_manager_(mlir::PassManager::on<mlir::ModuleOp>(context)),
tiled_pass_manager_(mlir::PassManager::on<mlir::ModuleOp>(context)) {
// Scalar passes.
AddScalarOptimizationPasses(scalar_optimization_pass_manager_,
options_.vector_width);
AddScalarLoweringPasses(scalar_lowering_pass_manager_, options_.vector_width,
AddScalarOptimizationPasses(scalar_pass_manager_, options_.vector_width);
if (hooks_.post_optimization) {
scalar_pass_manager_.addPass(
std::make_unique<ModuleCallbackPass>(hooks_.post_optimization));
}
AddScalarLoweringPasses(scalar_pass_manager_, options_.vector_width,
options_.fast_min_max);
// Tiled passes.
AddTiledOptimizationPasses(tiled_optimization_pass_manager_);
AddTiledLoweringPasses(tiled_lowering_pass_manager_);
AddTiledOptimizationPasses(tiled_pass_manager_);
if (hooks_.post_optimization) {
tiled_pass_manager_.addPass(
std::make_unique<ModuleCallbackPass>(hooks_.post_optimization));
}
AddTiledLoweringPasses(tiled_pass_manager_);
scalar_optimization_pass_manager_.addInstrumentation(
scalar_pass_manager_.addInstrumentation(
std::make_unique<TraceInstrumentation>());
scalar_lowering_pass_manager_.addInstrumentation(
std::make_unique<TraceInstrumentation>());
tiled_optimization_pass_manager_.addInstrumentation(
std::make_unique<TraceInstrumentation>());
tiled_lowering_pass_manager_.addInstrumentation(
tiled_pass_manager_.addInstrumentation(
std::make_unique<TraceInstrumentation>());
}
@@ -317,11 +329,7 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> FusionCompiler::Compile(
};
bool is_tiled = !mlir_module.getBody()->getOps<xtile::EntryFuncOp>().empty();
mlir::PassManager& optimization_pm = is_tiled
? tiled_optimization_pass_manager_
: scalar_optimization_pass_manager_;
mlir::PassManager& lowering_pm =
is_tiled ? tiled_lowering_pass_manager_ : scalar_lowering_pass_manager_;
mlir::PassManager& pm = is_tiled ? tiled_pass_manager_ : scalar_pass_manager_;
VLOG(1) << "Compiling MLIR module: " << module_name << ", with "
<< get_module_op_count() << " operations.";
@@ -337,15 +345,8 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> FusionCompiler::Compile(
if (hooks_.pre_optimization) {
hooks_.pre_optimization(mlir_module);
}
TF_RETURN_IF_ERROR(RunPassPipeline(mlir_module, optimization_pm, nullptr,
options_.verification_level));
if (hooks_.post_optimization) {
hooks_.post_optimization(mlir_module);
}
TF_RETURN_IF_ERROR(RunPassPipeline(mlir_module, lowering_pm, nullptr,
options_.verification_level));
TF_RETURN_IF_ERROR(
RunPassPipeline(mlir_module, pm, nullptr, options_.verification_level));
if (hooks_.post_lowering) {
hooks_.post_lowering(mlir_module);

View File

@@ -67,19 +67,11 @@ class FusionCompiler {
private:
Options options_;
CompilationHooks hooks_;
// The reason we have 4 distinct pass managers is because:
// - We have 2 stages: optimization and lowering, this is to enable dumping
// of the intermediate optimized MLIR.
// - We have 2 distinct pipelines for scalar and tiled kernels, this is
// because they differ slightly in their semantics, ideally these would be
// unified but this is a larger change.
// Pass manager that holds the optimization & loop transformation passes.
mlir::PassManager scalar_optimization_pass_manager_;
mlir::PassManager tiled_optimization_pass_manager_;
// Pass manager that holds the passes responsible for lowering the module from
// MLIR to LLVM.
mlir::PassManager scalar_lowering_pass_manager_;
mlir::PassManager tiled_lowering_pass_manager_;
// We have 2 distinct pipelines for scalar and tiled kernels because they
// differ slightly in their semantics; ideally these would be unified, but
// that is a larger change.
mlir::PassManager scalar_pass_manager_;
mlir::PassManager tiled_pass_manager_;
};
} // namespace xla::cpu
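
The removed header comment notes that the separate optimization stage existed to enable dumping the intermediate optimized MLIR; with the stages merged, that capability moves to the post_optimization hook, which now runs as a pass in the middle of the combined pipeline. A rough caller-side sketch follows, assuming CompilationHooks exposes an assignable, std::function-like post_optimization member (the diff only shows that it is boolean-testable and callable with an mlir::ModuleOp).

// Hypothetical usage sketch; CompilationHooks' exact definition is not shown
// in this diff, so the assignable std::function-like member is an assumption.
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/BuiltinOps.h"

void InstallPostOptimizationDump(CompilationHooks& hooks) {
  hooks.post_optimization = [](mlir::ModuleOp module) {
    llvm::errs() << "=== MLIR after optimization, before lowering ===\n";
    module.print(llvm::errs());
  };
}

Because the callback pass sits between AddScalarOptimizationPasses/AddTiledOptimizationPasses and the corresponding lowering passes, such a hook observes the same intermediate module that the old two-pass-manager split exposed.
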