Refactor spmd/shardy export_named_computations and import_func_calls.

PiperOrigin-RevId: 825290126
Authored by Zixuan Jiang on 2025-10-28 18:14:26 -07:00; committed by TensorFlower Gardener
parent 9b433c3f5a
commit e3549cef96
3 changed files with 8 additions and 20 deletions

BUILD

@@ -58,7 +58,6 @@ cc_library(
     srcs = ["import_func_calls.cc"],
     hdrs = ["import_func_calls.h"],
    deps = [
-        "//xla/service/spmd/shardy:constants",
         "//xla/service/spmd/shardy:utils",
         "@com_google_absl//absl/log:check",
         "@llvm-project//llvm:Support",

export_named_computations.cc

@@ -54,12 +54,12 @@ namespace {
 using ::mlir::ArrayAttr;
 using ::mlir::ModuleOp;
 using ::mlir::NamedAttribute;
+using ::mlir::StringAttr;
 using ::mlir::StringRef;
 using ::mlir::SymbolTable;
 using ::mlir::func::CallOp;
 using ::mlir::func::FuncOp;
-using ::mlir::StringAttr;
 using ::mlir::sdy::kShardingAttr;
 using ::mlir::sdy::ManualAxesAttr;
 using ::mlir::sdy::NamedComputationOp;
@@ -149,12 +149,11 @@ class ExportNamedComputationsPass
     this->dedupFunctionsFully = other.dedupFunctionsFully;
   }

-  llvm::SmallDenseMap<ComputationKey, StringAttr> funcCache;
-
   void runOnOperation() final {
     ModuleOp moduleOp = getOperation();
     SymbolTable symbolTable(moduleOp);
     mlir::Block& moduleBlock = moduleOp.getRegion().front();
+    llvm::SmallDenseMap<ComputationKey, StringAttr> funcCache;

     if (dedupFunctionsFully) {
       using FuncNameKey = std::pair<StringRef, ManualAxesAttr>;
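Note: the hunk above moves funcCache from a pass member into runOnOperation, so every run of ExportNamedComputationsPass starts with an empty cache instead of reusing entries (and the symbol names they point at) from a previous run on another module. A minimal sketch of the pattern, with a hypothetical DedupPassSketch standing in for the real MLIR pass:

#include <cstdio>
#include <map>
#include <string>

// Hypothetical stand-in for the pass: per-run state is declared
// inside run(), not as a member, so it cannot leak across runs.
class DedupPassSketch {
 public:
  void run(const std::string& module) {
    // Per-run cache, rebuilt on every invocation (like the hunk above).
    std::map<std::string, std::string> funcCache;
    funcCache[module] = module + "_canonical";
    std::printf("%zu cached entries this run\n", funcCache.size());
  }
  // A member cache here would survive across runs and could hand out
  // function names that no longer exist in the next module.
};

int main() {
  DedupPassSketch pass;
  pass.run("module_a");  // prints: 1 cached entries this run
  pass.run("module_b");  // prints: 1 cached entries this run (fresh cache)
}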

import_func_calls.cc

@@ -25,7 +25,6 @@ limitations under the License.
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
-#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/Threading.h"
 #include "mlir/Analysis/CallGraph.h"
@@ -46,7 +45,6 @@ limitations under the License.
 #include "shardy/dialect/sdy/ir/constants.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
 #include "shardy/dialect/sdy/ir/utils.h"
-#include "xla/service/spmd/shardy/constants.h"
 #include "xla/service/spmd/shardy/utils.h"

 namespace xla {
@@ -65,15 +63,6 @@ using ::mlir::sdy::NamedComputationOp;
 using ::mlir::sdy::TensorShardingAttr;
 using ::mlir::sdy::TensorShardingPerValueAttr;

-bool isInlineableCallOp(CallOp callOp) {
-  if (hasFrontendAttr(callOp, kXlaBackendConfigAttr)) {
-    return false;
-  }
-  auto inlineableAttr =
-      tryGetFrontendAttr<mlir::BoolAttr>(callOp, kXlaInlineableAttr);
-  return !inlineableAttr || inlineableAttr->getValue();
-}
-
 // Returns the first non-maximal mesh on the argument shardings, if there is
 // one. Otherwise returns `std::nullopt`.
 // TODO(enver): Move to utils and potentially with a common helper that takes an
@@ -130,11 +119,10 @@ void importCallOp(
   rewriter.setInsertionPoint(callOp);
   TensorShardingPerValueAttr callOpResultShardings =
       mlir::sdy::getShardingPerValue(callOp);
-  auto namedCompOp = rewriter.create<NamedComputationOp>(
-      callOp->getLoc(), callOp->getResultTypes(), calleeName,
+  auto namedCompOp = NamedComputationOp::create(
+      rewriter, callOp->getLoc(), callOp->getResultTypes(), calleeName,
       callOp.getOperands(),
-      /*inShardings=*/
-      getFuncArgShardings(callOp, funcOp, symbolTable),
+      /*inShardings=*/getFuncArgShardings(callOp, funcOp, symbolTable),
       // TODO(b/439018088): Take func result shardings if call op result
       // shardings are empty.
       /*outShardings=*/
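Note: this hunk switches from the templated builder member rewriter.create<NamedComputationOp>(...) to the generated static NamedComputationOp::create(rewriter, ...), the spelling recent MLIR favors; the arguments are unchanged apart from the builder moving to the first parameter. A hedged sketch of the two forms, using arith::AddIOp as a stand-in since the diff does not show NamedComputationOp's full build signature:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

// Both spellings build the same kind of op; only the call syntax differs.
mlir::Value makeSum(mlir::OpBuilder& builder, mlir::Location loc,
                    mlir::Value lhs, mlir::Value rhs) {
  // Old spelling: templated member function on OpBuilder.
  auto oldStyle = builder.create<mlir::arith::AddIOp>(loc, lhs, rhs);
  // New spelling: generated static create, builder passed explicitly.
  auto newStyle = mlir::arith::AddIOp::create(builder, loc, lhs, rhs);
  (void)oldStyle;
  return newStyle;
}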
@@ -187,7 +175,9 @@ class ImportFuncCallsPass
     mlir::CallGraph callGraph(moduleOp);
     llvm::ReversePostOrderTraversal<const mlir::CallGraph*> rpo(&callGraph);
     for (mlir::CallGraphNode* node : llvm::reverse(rpo)) {
-      if (node->isExternal()) continue;
+      if (node->isExternal()) {
+        continue;
+      }
       node->getCallableRegion()->walk([&](CallOp op) {
         importCallOp(op, calleeNameToMovedRegion, rewriter, symbolTable);
       });
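Note: the loop in the hunk above iterates llvm::reverse(rpo), i.e. a post-order of the call graph, so callees are visited before their callers and nested calls are already imported by the time an enclosing function is rewritten. A toy model of that ordering, assuming a hypothetical three-node call graph (0 = main, which calls 1 = f, which calls 2 = g):

#include <cstdio>
#include <functional>
#include <vector>

int main() {
  // Hypothetical call graph: node n's callees are callees[n].
  const std::vector<std::vector<int>> callees = {{1}, {2}, {}};
  std::vector<int> post;  // post-order accumulator
  std::vector<bool> seen(callees.size(), false);
  std::function<void(int)> dfs = [&](int n) {
    seen[n] = true;
    for (int c : callees[n])
      if (!seen[c]) dfs(c);
    post.push_back(n);  // a node is emitted only after all its callees
  };
  dfs(0);
  // Reverse post-order here is {0, 1, 2}; reversing it recovers the
  // post-order {2, 1, 0}, i.e. g and f are processed before main.
  for (int n : post) std::printf("%d ", n);  // prints: 2 1 0
}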