Close output shardings to respect allow_spmd_sharding_propagation_to_output flag set to default {false} value. Added multiple test variants to test shardy, use_compile_options_from_model.

PiperOrigin-RevId: 784280731
This commit is contained in:
Kanish Anand 2025-07-17 12:51:16 -07:00 committed by TensorFlower Gardener
parent 026ccaa614
commit ed20d76183
3 changed files with 34 additions and 0 deletions

View File

@ -75,6 +75,7 @@ cc_library(
"//xla/service/spmd/shardy/extensions:mhlo_extensions",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:FuncDialect",

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mhlo/IR/register.h"
#include "absl/log/check.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
@ -207,6 +208,31 @@ void loadAllRequiredDialects(mlir::MLIRContext* context) {
context->loadAllAvailableDialects();
}
void adjustOutputSharding(
FuncOp func, int idx, TensorShardingAttr sharding, int64_t rank,
absl::Span<const bool> allowSpmdShardingPropagationToOutput) {
bool allowPropagation = false;
if (!allowSpmdShardingPropagationToOutput.empty()) {
allowPropagation = allowSpmdShardingPropagationToOutput.size() == 1
? allowSpmdShardingPropagationToOutput[0]
: allowSpmdShardingPropagationToOutput[idx];
}
if (allowPropagation) {
return;
}
// Close all dimensions if sharding propagation to outputs is not allowed.
if (sharding) {
sharding = sharding.getClosedLike(sharding);
} else {
sharding = TensorShardingAttr::getFullyClosed(
func.getContext(), rank,
MeshAttr::get(func.getContext(), mlir::ArrayRef<MeshAxisAttr>{}));
}
setFuncResultSharding(func, idx, sharding);
}
CustomCallOp cloneCustomCallWithNewResultTypes(CustomCallOp op,
mlir::TypeRange resultTypes,
mlir::IRRewriter& rewriter) {

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/strings/escaping.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
@ -79,6 +80,12 @@ bool hasKey(mlir::DictionaryAttr dictAttr, mlir::StringRef key);
void loadAllRequiredDialects(mlir::MLIRContext* context);
// Adjusts the output sharding based on allowSpmdShardingPropagationToOutput
// flag.
void adjustOutputSharding(
mlir::func::FuncOp func, int idx, mlir::sdy::TensorShardingAttr sharding,
int64_t rank, absl::Span<const bool> allowSpmdShardingPropagationToOutput);
// Parses `escapedValue` to an attribute of type `AttrTy`.
template <typename AttrTy>
AttrTy parseStringAttr(llvm::StringRef escapedValue,