Reshard the call output if its sharding mismatches the func result sharding.
This is behaviorally a no-op for Shardy by default: the call output and the func result can only mismatch when the dedup-functions-fully option is true, and that option defaults to false. During the Shardy partitioner, Shardy adds explicit reshards on the operations that use the output of a named computation, assuming that output is sharded as specified in the named computation's out shardings. When dedup-functions-fully is true, however, the function that is actually called may end up with a different output sharding than the corresponding named computation, while the users of the output should still see the sharding given in the named computation's out shardings. Hence, if the out sharding of the named computation and the result sharding of the called function mismatch, we add a reshard on the output of the call.

PiperOrigin-RevId: 823494391
This commit is contained in:
parent 0c0947cea6
commit 69c93c6f6a
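As a condensed illustration, adapted from the updated test expectations below (this assumes dedup-functions-fully=true dedups two "baz" named computations with differing out shardings into a single func @baz): both call sites inherit @baz's result sharding, and the call whose named computation declared a different out sharding gets an explicit reshard on its output:

%0 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
%1 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
// The second named computation declared out sharding <@mesh, [{"x"}, {"y"}]>,
// so its users are rewired through a reshard that restores that sharding:
%2 = mhlo.copy %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : tensor<8x2xi32>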
@@ -39,6 +39,7 @@ cc_library(
     srcs = ["export_named_computations.cc"],
     hdrs = ["export_named_computations.h"],
     deps = [
+        "//xla/mlir_hlo",
         "//xla/service/spmd/shardy:constants",
         "//xla/service/spmd/shardy:utils",
         "@com_google_absl//absl/log:check",
@@ -30,6 +30,7 @@ limitations under the License.
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/DialectRegistry.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
@@ -41,6 +42,7 @@ limitations under the License.
 #include "shardy/dialect/sdy/ir/constants.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
 #include "shardy/dialect/sdy/ir/utils.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 #include "xla/service/spmd/shardy/constants.h"
 #include "xla/service/spmd/shardy/utils.h"
 
@@ -231,13 +233,24 @@ class ExportNamedComputationsPass
     callOp->setAttrs(callOpAttrs);
 
     // Copy the func output shardings to the call op.
-    // TODO(enver): Add explicit reshard if callOp and funcOp result shardings
-    // mismatch.
     FuncOp funcOp = symbolTable.lookup<FuncOp>(funcSymName);
     if (TensorShardingPerValueAttr funcResultShardings =
             getFuncResultShardings(callOp, funcOp, symbolTable);
         funcResultShardings) {
       mlir::sdy::setShardings(callOp, funcResultShardings);
+      if (outShardings.has_value()) {
+        for (auto [funcResultSharding, outSharding, result] : llvm::zip_equal(
+                 funcResultShardings.getShardings(),
+                 outShardings->getShardings(), callOp.getResults())) {
+          if (!funcResultSharding.isEquivalent(outSharding)) {
+            rewriter.setInsertionPointAfterValue(result);
+            auto copyOp =
+                mlir::mhlo::CopyOp::create(rewriter, result.getLoc(), result);
+            mlir::sdy::setShardings(copyOp, outSharding);
+            rewriter.replaceAllUsesExcept(result, copyOp, copyOp);
+          }
+        }
+      }
     }
     if (manualAxesAttr) {
       callOp->setAttr(kManualAxes, manualAxesAttr);
     }
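A note on the rewrite above: rewriter.replaceAllUsesExcept(result, copyOp, copyOp) rewires every user of the call result to the reshard copy while excluding the copy op itself, so the copy keeps reading the original call output rather than its own result.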
@@ -256,6 +269,10 @@ class ExportNamedComputationsPass
             "`NamedComputationOp`s operands/results.";
   }
 
+  void getDependentDialects(mlir::DialectRegistry& registry) const final {
+    registry.insert<mlir::sdy::SdyDialect, mlir::mhlo::MhloDialect>();
+  }
+
   Option<bool> dedupFunctionsFully{
       *this, "dedup-functions-fully",
       llvm::cl::desc(
@@ -3,10 +3,12 @@
 sdy.mesh @mesh = <["x"=2, "y"=2]>
 
 // CHECK-LABEL: func @multiple_same_named_computations_different_shardings(
-func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
-  // CHECK-NEXT: %0 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
-  // CHECK-NEXT: %1 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
-  // CHECK-NEXT: return %1 : tensor<8x2xi32>
+func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
+  // CHECK-NEXT: %[[CALL0:.*]] = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>}
+  // CHECK-NEXT: %[[CALL1:.*]] = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>}
+  // CHECK-NEXT: %[[COPY:.*]] = mhlo.copy %[[CALL1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>}
+  // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[CALL0]], %[[COPY]]
+  // CHECK-NEXT: return %[[ADD]]
   %0 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) {
     %2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>} : tensor<8x2xi32>
     sdy.return %2 : tensor<8x2xi32>
@@ -15,7 +17,8 @@ func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x
     %3 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
     sdy.return %3 : tensor<8x2xi32>
   } : (tensor<8x2xi32>) -> tensor<8x2xi32>
-  return %1 : tensor<8x2xi32>
+  %4 = stablehlo.add %0, %1 : tensor<8x2xi32>
+  return %4 : tensor<8x2xi32>
 }
 
 // CHECK-LABEL: func private @baz(
@@ -29,11 +32,12 @@ func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x
 sdy.mesh @mesh = <["x"=2, "y"=2]>
 
 // CHECK-LABEL: func @multiple_same_named_computations_different_shardings_different_number_of_call_sites(
-func.func @multiple_same_named_computations_different_shardings_different_number_of_call_sites(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
-  // CHECK-NEXT: %0 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
-  // CHECK-NEXT: %1 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
-  // CHECK-NEXT: %2 = call @baz(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
-  // CHECK-NEXT: return %2 : tensor<8x2xi32>
+func.func @multiple_same_named_computations_different_shardings_different_number_of_call_sites(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
+  // CHECK-NEXT: %[[CALL0:.*]] = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>}
+  // CHECK-NEXT: %[[COPY:.*]] = mhlo.copy %[[CALL0]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>}
+  // CHECK-NEXT: %[[CALL1:.*]] = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>}
+  // CHECK-NEXT: %[[CALL2:.*]] = call @baz(%[[COPY]]) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>}
+  // CHECK-NEXT: return %[[CALL2]] : tensor<8x2xi32>
   %0 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) {
     %2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>} : tensor<8x2xi32>
     sdy.return %2 : tensor<8x2xi32>
@@ -59,13 +63,46 @@ func.func @multiple_same_named_computations_different_shardings_different_number
 
 sdy.mesh @mesh = <["x"=2, "y"=2]>
 
+// CHECK-LABEL: func @multiple_same_named_computations_multiple_outputs_different_shardings(
+func.func @multiple_same_named_computations_multiple_outputs_different_shardings(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
+  // CHECK-NEXT: %[[CALL0:.*]]:2 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>, <@mesh, [{}, {"y"}]>]>}
+  // CHECK-NEXT: %[[DIVIDE0:.*]] = stablehlo.divide %[[CALL0]]#0, %[[CALL0]]#1
+  // CHECK-NEXT: %[[CALL1:.*]]:2 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>, <@mesh, [{}, {"y"}]>]>}
+  // CHECK-NEXT: %[[COPY0:.*]] = mhlo.copy %[[CALL1]]#1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}, {"x"}]>]>}
+  // CHECK-NEXT: %[[COPY1:.*]] = mhlo.copy %[[CALL1]]#0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>}
+  // CHECK-NEXT: %[[DIVIDE1:.*]] = stablehlo.divide %[[COPY1]], %[[COPY0]]
+  // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[DIVIDE0]], %[[DIVIDE1]]
+  // CHECK-NEXT: return %[[ADD]]
+  %0:2 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>, <@mesh, [{}, {"y"}]>] (%arg1: tensor<8x2xi32>) {
+    %5 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>} : tensor<8x2xi32>
+    sdy.return %5, %5 : tensor<8x2xi32>, tensor<8x2xi32>
+  } : (tensor<8x2xi32>) -> (tensor<8x2xi32>, tensor<8x2xi32>)
+  %1 = stablehlo.divide %0#0, %0#1 : tensor<8x2xi32>
+  %2:2 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {"y"}]>, <@mesh, [{"y"}, {"x"}]>] (%arg1: tensor<8x2xi32>) {
+    %5 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
+    sdy.return %5, %5 : tensor<8x2xi32>, tensor<8x2xi32>
+  } : (tensor<8x2xi32>) -> (tensor<8x2xi32>, tensor<8x2xi32>)
+  %3 = stablehlo.divide %2#0, %2#1 : tensor<8x2xi32>
+  %4 = stablehlo.add %1, %3 : tensor<8x2xi32>
+  return %4 : tensor<8x2xi32>
+}
+
+// CHECK-LABEL: func private @baz(
+// CHECK-SAME: %arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>})
+// CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}, tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>})
+// CHECK-NEXT: stablehlo.multiply %arg0, %arg0
+// CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>}
+
+// -----
+
+sdy.mesh @mesh = <["x"=2, "y"=2]>
 
 // CHECK-LABEL: func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(
 // CHECK-SAME: %arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}
 // CHECK-SAME: -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) {
 // CHECK-NEXT: %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
 // CHECK-NEXT: %3 = func.call @foo(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
 // CHECK-NEXT: %4 = func.call @foo(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK-NEXT: sdy.return %4 : tensor<4xf32>
+// CHECK-NEXT: %5 = mhlo.copy %4 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}]>]>} : tensor<4xf32>
+// CHECK-NEXT: sdy.return %5 : tensor<4xf32>
 // CHECK-NEXT: } : (tensor<8xf32>) -> tensor<8xf32>
 // CHECK-NEXT: %1 = call @foo_0(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<8xf32>) -> tensor<8xf32>
 // CHECK-NEXT: %2 = sdy.manual_computation(%1) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
@@ -73,7 +110,7 @@ sdy.mesh @mesh = <["x"=2, "y"=2]>
 // CHECK-NEXT: sdy.return %3 : tensor<4xf32>
 // CHECK-NEXT: } : (tensor<8xf32>) -> tensor<8xf32>
 // CHECK-NEXT: return %2 : tensor<8xf32>
-func.func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) {
+func.func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(%arg0: tensor<8xf32>) -> tensor<8xf32> {
   %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
     %1 = sdy.named_computation<"foo">(%arg1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg2: tensor<4xf32>) {
       %2 = stablehlo.abs %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<4xf32>