Reshard on call output if sharding mismatches with the func result.

It is no-op behaviorally for shardy. Because the call output and func result may mismatch only if dedup-functions-fully options is true, and this option is false by default.

Shardy will add explicit reshards (during shardy partitioner) on those operations that use the output of named computation and it will do so assuming the sharding of the named computation is sharded as specified in the out shardings of the named computation.

When dedup-functions-fully option is true, however, the function that is actually called may end up having a different output sharding than the corresponding named computation. So, the users of the output shardings should still use sharding as in the output shardings the named computation. Hence, if there is a mismatch between the output sharding of the named computation and the result sharding of the function, we add a reshard on the output of the call.

PiperOrigin-RevId: 823494391
This commit is contained in:
A. Unique TensorFlower 2025-10-24 05:49:46 -07:00 committed by TensorFlower Gardener
parent 0c0947cea6
commit 69c93c6f6a
3 changed files with 71 additions and 16 deletions

View File

@ -39,6 +39,7 @@ cc_library(
srcs = ["export_named_computations.cc"],
hdrs = ["export_named_computations.h"],
deps = [
"//xla/mlir_hlo",
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@com_google_absl//absl/log:check",

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
@ -41,6 +42,7 @@ limitations under the License.
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/utils.h"
@ -231,13 +233,24 @@ class ExportNamedComputationsPass
callOp->setAttrs(callOpAttrs);
// Copy the func output shardings to the call op.
// TODO(enver): Add explicit reshard if callOp and funcOp result shardings
// mismatch.
FuncOp funcOp = symbolTable.lookup<FuncOp>(funcSymName);
if (TensorShardingPerValueAttr funcResultShardings =
getFuncResultShardings(callOp, funcOp, symbolTable);
funcResultShardings) {
mlir::sdy::setShardings(callOp, funcResultShardings);
if (outShardings.has_value()) {
for (auto [funcResultSharding, outSharding, result] : llvm::zip_equal(
funcResultShardings.getShardings(),
outShardings->getShardings(), callOp.getResults())) {
if (!funcResultSharding.isEquivalent(outSharding)) {
rewriter.setInsertionPointAfterValue(result);
auto copyOp =
mlir::mhlo::CopyOp::create(rewriter, result.getLoc(), result);
mlir::sdy::setShardings(copyOp, outSharding);
rewriter.replaceAllUsesExcept(result, copyOp, copyOp);
}
}
}
if (manualAxesAttr) {
callOp->setAttr(kManualAxes, manualAxesAttr);
}
@ -256,6 +269,10 @@ class ExportNamedComputationsPass
"`NamedComputationOp`s operands/results.";
}
void getDependentDialects(mlir::DialectRegistry& registry) const final {
registry.insert<mlir::sdy::SdyDialect, mlir::mhlo::MhloDialect>();
}
Option<bool> dedupFunctionsFully{
*this, "dedup-functions-fully",
llvm::cl::desc(

View File

@ -3,10 +3,12 @@
sdy.mesh @mesh = <["x"=2, "y"=2]>
// CHECK-LABEL: func @multiple_same_named_computations_different_shardings(
func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
// CHECK-NEXT: %0 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
// CHECK-NEXT: %1 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
// CHECK-NEXT: return %1 : tensor<8x2xi32>
func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
// CHECK-NEXT: %[[CALL0:.*]] = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>}
// CHECK-NEXT: %[[CALL1:.*]] = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>}
// CHECK-NEXT: %[[COPY:.*]] = mhlo.copy %[[CALL1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>}
// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[CALL0]], %[[COPY]]
// CHECK-NEXT: return %[[ADD]]
%0 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) {
%2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>} : tensor<8x2xi32>
sdy.return %2 : tensor<8x2xi32>
@ -15,7 +17,8 @@ func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x
%3 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
sdy.return %3 : tensor<8x2xi32>
} : (tensor<8x2xi32>) -> tensor<8x2xi32>
return %1 : tensor<8x2xi32>
%4 = stablehlo.add %0, %1 : tensor<8x2xi32>
return %4 : tensor<8x2xi32>
}
// CHECK-LABEL: func private @baz(
@ -29,11 +32,12 @@ func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x
sdy.mesh @mesh = <["x"=2, "y"=2]>
// CHECK-LABEL: func @multiple_same_named_computations_different_shardings_different_number_of_call_sites(
func.func @multiple_same_named_computations_different_shardings_different_number_of_call_sites(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
// CHECK-NEXT: %0 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
// CHECK-NEXT: %1 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
// CHECK-NEXT: %2 = call @baz(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
// CHECK-NEXT: return %2 : tensor<8x2xi32>
func.func @multiple_same_named_computations_different_shardings_different_number_of_call_sites(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
// CHECK-NEXT: %[[CALL0:.*]] = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>}
// CHECK-NEXT: %[[COPY:.*]] = mhlo.copy %[[CALL0]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>}
// CHECK-NEXT: %[[CALL1:.*]] = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>}
// CHECK-NEXT: %[[CALL2:.*]] = call @baz(%[[COPY]]) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>}
// CHECK-NEXT: return %[[CALL2]] : tensor<8x2xi32>
%0 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) {
%2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>} : tensor<8x2xi32>
sdy.return %2 : tensor<8x2xi32>
@ -59,13 +63,46 @@ func.func @multiple_same_named_computations_different_shardings_different_number
sdy.mesh @mesh = <["x"=2, "y"=2]>
// CHECK-LABEL: func @multiple_same_named_computations_multiple_outputs_different_shardings(
func.func @multiple_same_named_computations_multiple_outputs_different_shardings(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
// CHECK-NEXT: %[[CALL0:.*]]:2 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>, <@mesh, [{}, {"y"}]>]>}
// CHECK-NEXT: %[[DIVIDE0:.*]] = stablehlo.divide %[[CALL0]]#0, %[[CALL0]]#1
// CHECK-NEXT: %[[CALL1:.*]]:2 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>, <@mesh, [{}, {"y"}]>]>}
// CHECK-NEXT: %[[COPY0:.*]] = mhlo.copy %[[CALL1]]#1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {"x"}]>]>}
// CHECK-NEXT: %[[COPY1:.*]] = mhlo.copy %[[CALL1]]#0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>}
// CHECK-NEXT: %[[DIVIDE1:.*]] = stablehlo.divide %[[COPY1]], %[[COPY0]]
// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[DIVIDE0]], %[[DIVIDE1]]
// CHECK-NEXT: return %[[ADD]]
%0:2 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>, <@mesh, [{}, {"y"}]>] (%arg1: tensor<8x2xi32>) {
%5 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>} : tensor<8x2xi32>
sdy.return %5, %5 : tensor<8x2xi32>, tensor<8x2xi32>
} : (tensor<8x2xi32>) -> (tensor<8x2xi32>, tensor<8x2xi32>)
%1 = stablehlo.divide %0#0, %0#1 : tensor<8x2xi32>
%2:2 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{"y"}, {"x"}]>] (%arg1: tensor<8x2xi32>) {
%5 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
sdy.return %5, %5 : tensor<8x2xi32>, tensor<8x2xi32>
} : (tensor<8x2xi32>) -> (tensor<8x2xi32>, tensor<8x2xi32>)
%3 = stablehlo.divide %2#0, %2#1 : tensor<8x2xi32>
%4 = stablehlo.add %1, %3 : tensor<8x2xi32>
return %4 : tensor<8x2xi32>
}
// CHECK-LABEL: func private @baz(
// CHECK-SAME: %arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>})
// CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}, tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>})
// CHECK-NEXT: stablehlo.multiply %arg0, %arg0
// CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>}
// -----
sdy.mesh @mesh = <["x"=2, "y"=2]>
// CHECK-LABEL: func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(
// CHECK-SAME: %arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}
// CHECK-SAME: -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) {
// CHECK-NEXT: %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
// CHECK-NEXT: %3 = func.call @foo(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %4 = func.call @foo(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: sdy.return %4 : tensor<4xf32>
// CHECK-NEXT: %5 = mhlo.copy %4 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}]>]>} : tensor<4xf32>
// CHECK-NEXT: sdy.return %5 : tensor<4xf32>
// CHECK-NEXT: } : (tensor<8xf32>) -> tensor<8xf32>
// CHECK-NEXT: %1 = call @foo_0(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<8xf32>) -> tensor<8xf32>
// CHECK-NEXT: %2 = sdy.manual_computation(%1) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
@ -73,7 +110,7 @@ sdy.mesh @mesh = <["x"=2, "y"=2]>
// CHECK-NEXT: sdy.return %3 : tensor<4xf32>
// CHECK-NEXT: } : (tensor<8xf32>) -> tensor<8xf32>
// CHECK-NEXT: return %2 : tensor<8xf32>
func.func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) {
func.func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(%arg0: tensor<8xf32>) -> tensor<8xf32> {
%0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
%1 = sdy.named_computation<"foo">(%arg1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg2: tensor<4xf32>) {
%2 = stablehlo.abs %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<4xf32>