mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Refactor spmd/shardy export_named_computations and import_func_calls.
PiperOrigin-RevId: 825290126
This commit is contained in:
parent
9b433c3f5a
commit
e3549cef96
|
|
@ -58,7 +58,6 @@ cc_library(
|
||||||
srcs = ["import_func_calls.cc"],
|
srcs = ["import_func_calls.cc"],
|
||||||
hdrs = ["import_func_calls.h"],
|
hdrs = ["import_func_calls.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//xla/service/spmd/shardy:constants",
|
|
||||||
"//xla/service/spmd/shardy:utils",
|
"//xla/service/spmd/shardy:utils",
|
||||||
"@com_google_absl//absl/log:check",
|
"@com_google_absl//absl/log:check",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
|
|
|
||||||
|
|
@ -54,12 +54,12 @@ namespace {
|
||||||
using ::mlir::ArrayAttr;
|
using ::mlir::ArrayAttr;
|
||||||
using ::mlir::ModuleOp;
|
using ::mlir::ModuleOp;
|
||||||
using ::mlir::NamedAttribute;
|
using ::mlir::NamedAttribute;
|
||||||
|
using ::mlir::StringAttr;
|
||||||
using ::mlir::StringRef;
|
using ::mlir::StringRef;
|
||||||
using ::mlir::SymbolTable;
|
using ::mlir::SymbolTable;
|
||||||
using ::mlir::func::CallOp;
|
using ::mlir::func::CallOp;
|
||||||
using ::mlir::func::FuncOp;
|
using ::mlir::func::FuncOp;
|
||||||
|
|
||||||
using ::mlir::StringAttr;
|
|
||||||
using ::mlir::sdy::kShardingAttr;
|
using ::mlir::sdy::kShardingAttr;
|
||||||
using ::mlir::sdy::ManualAxesAttr;
|
using ::mlir::sdy::ManualAxesAttr;
|
||||||
using ::mlir::sdy::NamedComputationOp;
|
using ::mlir::sdy::NamedComputationOp;
|
||||||
|
|
@ -149,12 +149,11 @@ class ExportNamedComputationsPass
|
||||||
this->dedupFunctionsFully = other.dedupFunctionsFully;
|
this->dedupFunctionsFully = other.dedupFunctionsFully;
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallDenseMap<ComputationKey, StringAttr> funcCache;
|
|
||||||
|
|
||||||
void runOnOperation() final {
|
void runOnOperation() final {
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
SymbolTable symbolTable(moduleOp);
|
SymbolTable symbolTable(moduleOp);
|
||||||
mlir::Block& moduleBlock = moduleOp.getRegion().front();
|
mlir::Block& moduleBlock = moduleOp.getRegion().front();
|
||||||
|
llvm::SmallDenseMap<ComputationKey, StringAttr> funcCache;
|
||||||
|
|
||||||
if (dedupFunctionsFully) {
|
if (dedupFunctionsFully) {
|
||||||
using FuncNameKey = std::pair<StringRef, ManualAxesAttr>;
|
using FuncNameKey = std::pair<StringRef, ManualAxesAttr>;
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,6 @@ limitations under the License.
|
||||||
#include "llvm/ADT/PostOrderIterator.h"
|
#include "llvm/ADT/PostOrderIterator.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "llvm/Support/Threading.h"
|
#include "llvm/Support/Threading.h"
|
||||||
#include "mlir/Analysis/CallGraph.h"
|
#include "mlir/Analysis/CallGraph.h"
|
||||||
|
|
@ -46,7 +45,6 @@ limitations under the License.
|
||||||
#include "shardy/dialect/sdy/ir/constants.h"
|
#include "shardy/dialect/sdy/ir/constants.h"
|
||||||
#include "shardy/dialect/sdy/ir/dialect.h"
|
#include "shardy/dialect/sdy/ir/dialect.h"
|
||||||
#include "shardy/dialect/sdy/ir/utils.h"
|
#include "shardy/dialect/sdy/ir/utils.h"
|
||||||
#include "xla/service/spmd/shardy/constants.h"
|
|
||||||
#include "xla/service/spmd/shardy/utils.h"
|
#include "xla/service/spmd/shardy/utils.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
@ -65,15 +63,6 @@ using ::mlir::sdy::NamedComputationOp;
|
||||||
using ::mlir::sdy::TensorShardingAttr;
|
using ::mlir::sdy::TensorShardingAttr;
|
||||||
using ::mlir::sdy::TensorShardingPerValueAttr;
|
using ::mlir::sdy::TensorShardingPerValueAttr;
|
||||||
|
|
||||||
bool isInlineableCallOp(CallOp callOp) {
|
|
||||||
if (hasFrontendAttr(callOp, kXlaBackendConfigAttr)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto inlineableAttr =
|
|
||||||
tryGetFrontendAttr<mlir::BoolAttr>(callOp, kXlaInlineableAttr);
|
|
||||||
return !inlineableAttr || inlineableAttr->getValue();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the first non-maximal mesh on the argument shardings, if there is
|
// Returns the first non-maximal mesh on the argument shardings, if there is
|
||||||
// one. Otherwise returns `std::nullopt`.
|
// one. Otherwise returns `std::nullopt`.
|
||||||
// TODO(enver): Move to utils and potentially with a common helper that takes an
|
// TODO(enver): Move to utils and potentially with a common helper that takes an
|
||||||
|
|
@ -130,11 +119,10 @@ void importCallOp(
|
||||||
rewriter.setInsertionPoint(callOp);
|
rewriter.setInsertionPoint(callOp);
|
||||||
TensorShardingPerValueAttr callOpResultShardings =
|
TensorShardingPerValueAttr callOpResultShardings =
|
||||||
mlir::sdy::getShardingPerValue(callOp);
|
mlir::sdy::getShardingPerValue(callOp);
|
||||||
auto namedCompOp = rewriter.create<NamedComputationOp>(
|
auto namedCompOp = NamedComputationOp::create(
|
||||||
callOp->getLoc(), callOp->getResultTypes(), calleeName,
|
rewriter, callOp->getLoc(), callOp->getResultTypes(), calleeName,
|
||||||
callOp.getOperands(),
|
callOp.getOperands(),
|
||||||
/*inShardings=*/
|
/*inShardings=*/getFuncArgShardings(callOp, funcOp, symbolTable),
|
||||||
getFuncArgShardings(callOp, funcOp, symbolTable),
|
|
||||||
// TODO(b/439018088): Take func result shardings if call op result
|
// TODO(b/439018088): Take func result shardings if call op result
|
||||||
// shardings are empty.
|
// shardings are empty.
|
||||||
/*outShardings=*/
|
/*outShardings=*/
|
||||||
|
|
@ -187,7 +175,9 @@ class ImportFuncCallsPass
|
||||||
mlir::CallGraph callGraph(moduleOp);
|
mlir::CallGraph callGraph(moduleOp);
|
||||||
llvm::ReversePostOrderTraversal<const mlir::CallGraph*> rpo(&callGraph);
|
llvm::ReversePostOrderTraversal<const mlir::CallGraph*> rpo(&callGraph);
|
||||||
for (mlir::CallGraphNode* node : llvm::reverse(rpo)) {
|
for (mlir::CallGraphNode* node : llvm::reverse(rpo)) {
|
||||||
if (node->isExternal()) continue;
|
if (node->isExternal()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
node->getCallableRegion()->walk([&](CallOp op) {
|
node->getCallableRegion()->walk([&](CallOp op) {
|
||||||
importCallOp(op, calleeNameToMovedRegion, rewriter, symbolTable);
|
importCallOp(op, calleeNameToMovedRegion, rewriter, symbolTable);
|
||||||
});
|
});
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user