mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Close output shardings to respect allow_spmd_sharding_propagation_to_output flag set to default {false} value. Added multiple test variants to test shardy, use_compile_options_from_model.
PiperOrigin-RevId: 784280731
This commit is contained in:
parent
026ccaa614
commit
ed20d76183
|
|
@ -75,6 +75,7 @@ cc_library(
|
|||
"//xla/service/spmd/shardy/extensions:mhlo_extensions",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:AsmParser",
|
||||
"@llvm-project//mlir:FuncDialect",
|
||||
|
|
|
|||
26
third_party/xla/xla/service/spmd/shardy/utils.cc
vendored
26
third_party/xla/xla/service/spmd/shardy/utils.cc
vendored
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||
#include "mhlo/IR/register.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
|
@ -207,6 +208,31 @@ void loadAllRequiredDialects(mlir::MLIRContext* context) {
|
|||
context->loadAllAvailableDialects();
|
||||
}
|
||||
|
||||
void adjustOutputSharding(
|
||||
FuncOp func, int idx, TensorShardingAttr sharding, int64_t rank,
|
||||
absl::Span<const bool> allowSpmdShardingPropagationToOutput) {
|
||||
bool allowPropagation = false;
|
||||
if (!allowSpmdShardingPropagationToOutput.empty()) {
|
||||
allowPropagation = allowSpmdShardingPropagationToOutput.size() == 1
|
||||
? allowSpmdShardingPropagationToOutput[0]
|
||||
: allowSpmdShardingPropagationToOutput[idx];
|
||||
}
|
||||
|
||||
if (allowPropagation) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Close all dimensions if sharding propagation to outputs is not allowed.
|
||||
if (sharding) {
|
||||
sharding = sharding.getClosedLike(sharding);
|
||||
} else {
|
||||
sharding = TensorShardingAttr::getFullyClosed(
|
||||
func.getContext(), rank,
|
||||
MeshAttr::get(func.getContext(), mlir::ArrayRef<MeshAxisAttr>{}));
|
||||
}
|
||||
setFuncResultSharding(func, idx, sharding);
|
||||
}
|
||||
|
||||
CustomCallOp cloneCustomCallWithNewResultTypes(CustomCallOp op,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::IRRewriter& rewriter) {
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||
#include "absl/log/check.h"
|
||||
#include "absl/strings/escaping.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "mlir/AsmParser/AsmParser.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
|
@ -79,6 +80,12 @@ bool hasKey(mlir::DictionaryAttr dictAttr, mlir::StringRef key);
|
|||
|
||||
void loadAllRequiredDialects(mlir::MLIRContext* context);
|
||||
|
||||
// Adjusts the output sharding based on allowSpmdShardingPropagationToOutput
|
||||
// flag.
|
||||
void adjustOutputSharding(
|
||||
mlir::func::FuncOp func, int idx, mlir::sdy::TensorShardingAttr sharding,
|
||||
int64_t rank, absl::Span<const bool> allowSpmdShardingPropagationToOutput);
|
||||
|
||||
// Parses `escapedValue` to an attribute of type `AttrTy`.
|
||||
template <typename AttrTy>
|
||||
AttrTy parseStringAttr(llvm::StringRef escapedValue,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user