mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[XLA:CPU] Combine optimization & lowering pass managers by using callback pass.
PiperOrigin-RevId: 820610316
This commit is contained in:
parent
5da47fcdd8
commit
abc19d2d20
|
|
@ -157,11 +157,11 @@ cc_library(
|
|||
"//xla/codegen/xtile/ir:xtile",
|
||||
"//xla/mlir/tools/mlir_replay/public:compiler_trace_proto_cc",
|
||||
"//xla/mlir_hlo",
|
||||
"//xla/service/gpu/model/experimental:symbolic_expr",
|
||||
"//xla/tsl/framework/mlir:status_scoped_diagnostic_handler",
|
||||
"//xla/tsl/platform:errors",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"@com_google_absl//absl/functional:any_invocable",
|
||||
"@com_google_absl//absl/functional:function_ref",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/functional/function_ref.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
|
|
@ -65,6 +66,7 @@ limitations under the License.
|
|||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Visitors.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/WalkResult.h"
|
||||
|
|
@ -99,6 +101,19 @@ limitations under the License.
|
|||
|
||||
namespace xla::cpu {
|
||||
|
||||
class ModuleCallbackPass
|
||||
: public mlir::PassWrapper<ModuleCallbackPass,
|
||||
mlir::OperationPass<mlir::ModuleOp>> {
|
||||
public:
|
||||
explicit ModuleCallbackPass(absl::FunctionRef<void(mlir::ModuleOp)> callback)
|
||||
: callback_(callback) {}
|
||||
|
||||
void runOnOperation() override { callback_(getOperation()); }
|
||||
|
||||
private:
|
||||
absl::FunctionRef<void(mlir::ModuleOp)> callback_;
|
||||
};
|
||||
|
||||
static absl::Status RunPassPipeline(
|
||||
mlir::ModuleOp module, mlir::PassManager& pm,
|
||||
mlir::interpreter::MlirCompilationTrace* trace,
|
||||
|
|
@ -273,31 +288,28 @@ FusionCompiler::FusionCompiler(mlir::MLIRContext* context, Options options,
|
|||
CompilationHooks hooks)
|
||||
: options_(std::move(options)),
|
||||
hooks_(std::move(hooks)),
|
||||
scalar_optimization_pass_manager_(
|
||||
mlir::PassManager::on<mlir::ModuleOp>(context)),
|
||||
tiled_optimization_pass_manager_(
|
||||
mlir::PassManager::on<mlir::ModuleOp>(context)),
|
||||
scalar_lowering_pass_manager_(
|
||||
mlir::PassManager::on<mlir::ModuleOp>(context)),
|
||||
tiled_lowering_pass_manager_(
|
||||
mlir::PassManager::on<mlir::ModuleOp>(context)) {
|
||||
scalar_pass_manager_(mlir::PassManager::on<mlir::ModuleOp>(context)),
|
||||
tiled_pass_manager_(mlir::PassManager::on<mlir::ModuleOp>(context)) {
|
||||
// Scalar passes.
|
||||
AddScalarOptimizationPasses(scalar_optimization_pass_manager_,
|
||||
options_.vector_width);
|
||||
AddScalarLoweringPasses(scalar_lowering_pass_manager_, options_.vector_width,
|
||||
AddScalarOptimizationPasses(scalar_pass_manager_, options_.vector_width);
|
||||
if (hooks_.post_optimization) {
|
||||
scalar_pass_manager_.addPass(
|
||||
std::make_unique<ModuleCallbackPass>(hooks_.post_optimization));
|
||||
}
|
||||
AddScalarLoweringPasses(scalar_pass_manager_, options_.vector_width,
|
||||
options_.fast_min_max);
|
||||
|
||||
// Tiled passes.
|
||||
AddTiledOptimizationPasses(tiled_optimization_pass_manager_);
|
||||
AddTiledLoweringPasses(tiled_lowering_pass_manager_);
|
||||
AddTiledOptimizationPasses(tiled_pass_manager_);
|
||||
if (hooks_.post_optimization) {
|
||||
tiled_pass_manager_.addPass(
|
||||
std::make_unique<ModuleCallbackPass>(hooks_.post_optimization));
|
||||
}
|
||||
AddTiledLoweringPasses(tiled_pass_manager_);
|
||||
|
||||
scalar_optimization_pass_manager_.addInstrumentation(
|
||||
scalar_pass_manager_.addInstrumentation(
|
||||
std::make_unique<TraceInstrumentation>());
|
||||
scalar_lowering_pass_manager_.addInstrumentation(
|
||||
std::make_unique<TraceInstrumentation>());
|
||||
tiled_optimization_pass_manager_.addInstrumentation(
|
||||
std::make_unique<TraceInstrumentation>());
|
||||
tiled_lowering_pass_manager_.addInstrumentation(
|
||||
tiled_pass_manager_.addInstrumentation(
|
||||
std::make_unique<TraceInstrumentation>());
|
||||
}
|
||||
|
||||
|
|
@ -317,11 +329,7 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> FusionCompiler::Compile(
|
|||
};
|
||||
|
||||
bool is_tiled = !mlir_module.getBody()->getOps<xtile::EntryFuncOp>().empty();
|
||||
mlir::PassManager& optimization_pm = is_tiled
|
||||
? tiled_optimization_pass_manager_
|
||||
: scalar_optimization_pass_manager_;
|
||||
mlir::PassManager& lowering_pm =
|
||||
is_tiled ? tiled_lowering_pass_manager_ : scalar_lowering_pass_manager_;
|
||||
mlir::PassManager& pm = is_tiled ? tiled_pass_manager_ : scalar_pass_manager_;
|
||||
|
||||
VLOG(1) << "Compiling MLIR module: " << module_name << ", with "
|
||||
<< get_module_op_count() << " operations.";
|
||||
|
|
@ -337,15 +345,8 @@ absl::StatusOr<std::unique_ptr<llvm::Module>> FusionCompiler::Compile(
|
|||
if (hooks_.pre_optimization) {
|
||||
hooks_.pre_optimization(mlir_module);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(RunPassPipeline(mlir_module, optimization_pm, nullptr,
|
||||
options_.verification_level));
|
||||
|
||||
if (hooks_.post_optimization) {
|
||||
hooks_.post_optimization(mlir_module);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(RunPassPipeline(mlir_module, lowering_pm, nullptr,
|
||||
options_.verification_level));
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunPassPipeline(mlir_module, pm, nullptr, options_.verification_level));
|
||||
|
||||
if (hooks_.post_lowering) {
|
||||
hooks_.post_lowering(mlir_module);
|
||||
|
|
|
|||
|
|
@ -67,19 +67,11 @@ class FusionCompiler {
|
|||
private:
|
||||
Options options_;
|
||||
CompilationHooks hooks_;
|
||||
// The reason we have 4 distinct pass managers is because:
|
||||
// - We have 2 stages: optimization and lowering, this is to enable dumping
|
||||
// of the intermediate optimized MLIR.
|
||||
// - We have 2 distinct pipelines for scalar and tiled kernels, this is
|
||||
// because they differ slightly in their semantics, ideally these would be
|
||||
// unified but this is a larger change.
|
||||
// Pass manager that holds the optimization & loop transformation passes.
|
||||
mlir::PassManager scalar_optimization_pass_manager_;
|
||||
mlir::PassManager tiled_optimization_pass_manager_;
|
||||
// Pass manager that holds the passes responsible for lowering the module from
|
||||
// MLIR to LLVM.
|
||||
mlir::PassManager scalar_lowering_pass_manager_;
|
||||
mlir::PassManager tiled_lowering_pass_manager_;
|
||||
// We have 2 distinct pipelines for scalar and tiled kernels, this is
|
||||
// because they differ slightly in their semantics, ideally these would be
|
||||
// unified but this is a larger change.
|
||||
mlir::PassManager scalar_pass_manager_;
|
||||
mlir::PassManager tiled_pass_manager_;
|
||||
};
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user