Deduplicate functions onto the variant with the largest number of call sites.

Previously, when fully deduplicating, the variant to keep was picked arbitrarily.
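For intuition, here is a minimal stand-alone C++ sketch of the selection policy (plain std containers standing in for the MLIR types used by the actual pass below; `CallSite`, `pickMostCalledVariant`, and the `variant` string are hypothetical stand-ins for a named computation and its in/out-sharding signature):

// Sketch only: plain-C++ analogue of the pass's counting loop, not the
// actual XLA implementation. "variant" stands in for the in/out-sharding
// signature that distinguishes otherwise same-named computations.
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct CallSite {
  std::string name;     // callee name shared by all variants
  std::string variant;  // stand-in for the sharding signature
};

// For each callee name, keep the variant seen at the most call sites.
std::map<std::string, std::pair<std::string, int64_t>> pickMostCalledVariant(
    const std::vector<CallSite>& callSites) {
  std::map<std::pair<std::string, std::string>, int64_t> counts;
  std::map<std::string, std::pair<std::string, int64_t>> best;
  for (const CallSite& site : callSites) {
    // Count of earlier sightings of this exact (name, variant) pair,
    // mirroring funcCallSiteCounts[key]++ in the pass.
    const int64_t count = counts[{site.name, site.variant}]++;
    auto [it, inserted] = best.try_emplace(site.name, site.variant, count);
    if (!inserted && count > it->second.second) {
      it->second = {site.variant, count};  // the more popular variant wins
    }
  }
  return best;
}

With call sites {baz/A, baz/B, baz/B}, variant B wins for baz, matching the new test case added below.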

PiperOrigin-RevId: 822566069
Authored by A. Unique TensorFlower on 2025-10-22 06:44:53 -07:00; committed by TensorFlower Gardener
parent 83b84b3c46
commit 39506ad1cd
2 changed files with 104 additions and 11 deletions

@@ -15,9 +15,11 @@ limitations under the License.
 #include "xla/service/spmd/shardy/round_trip_common/export_named_computations.h"
 
+#include <cstdint>
 #include <memory>
 #include <optional>
 #include <tuple>
+#include <utility>
 
 #include "absl/log/check.h"
 #include "llvm/ADT/DenseMap.h"
@@ -151,6 +153,53 @@ class ExportNamedComputationsPass
     ModuleOp moduleOp = getOperation();
     SymbolTable symbolTable(moduleOp);
     mlir::Block& moduleBlock = moduleOp.getRegion().front();
+    if (dedupFunctionsFully) {
+      llvm::SmallDenseMap<ComputationKey, int64_t> funcCallSiteCounts;
+      llvm::SmallDenseMap<std::pair<StringRef, ManualAxesAttr>,
+                          std::pair<NamedComputationOp, int64_t>>
+          funcToNamedComputations;
+      moduleOp.walk([&](NamedComputationOp namedComputationOp) {
+        ManualAxesAttr manualAxesAttr =
+            namedComputationOp->getAttrOfType<ManualAxesAttr>(kManualAxes);
+        auto key =
+            std::make_tuple(namedComputationOp.getName(),
+                            namedComputationOp.getInShardings().value_or(
+                                TensorShardingPerValueAttr()),
+                            namedComputationOp.getOutShardings().value_or(
+                                TensorShardingPerValueAttr()),
+                            manualAxesAttr);
+        const int64_t callSiteCount = funcCallSiteCounts[key]++;
+        if (auto [it, inserted] = funcToNamedComputations.try_emplace(
+                std::pair(namedComputationOp.getName(), manualAxesAttr),
+                namedComputationOp, callSiteCount);
+            !inserted) {
+          auto& [cachedNamedComputationOp, cachedCallSiteCount] = it->second;
+          if (callSiteCount > cachedCallSiteCount) {
+            cachedNamedComputationOp = namedComputationOp;
+            cachedCallSiteCount = callSiteCount;
+          }
+        }
+      });
+      for (auto& [_, namedComputationCountPair] : funcToNamedComputations) {
+        auto& [namedComputationOp, callSiteCount] = namedComputationCountPair;
+        mlir::IRRewriter rewriter(namedComputationOp);
+        rewriter.setInsertionPointToEnd(&moduleBlock);
+        ManualAxesAttr manualAxesAttr =
+            namedComputationOp->getAttrOfType<ManualAxesAttr>(kManualAxes);
+        StringAttr funcSymName =
+            createFuncOp(namedComputationOp, rewriter, symbolTable,
+                         namedComputationOp.getInShardings(),
+                         namedComputationOp.getOutShardings(), manualAxesAttr);
+        funcCache.try_emplace(
+            std::make_tuple(namedComputationOp.getName(),
+                            TensorShardingPerValueAttr(),
+                            TensorShardingPerValueAttr(), manualAxesAttr),
+            funcSymName);
+      }
+    }
     // NOTE: The walk needs to be in post order, which is the default order, to
     // account for nested named computations.
     moduleOp.walk([&](NamedComputationOp namedComputationOp) {
@@ -210,10 +259,10 @@ class ExportNamedComputationsPass
   Option<bool> dedupFunctionsFully{
       *this, "dedup-functions-fully",
       llvm::cl::desc(
-          "Whether to deduplicate functions fully, regardless of the input and "
-          "output shardings of functions, and it keeps one callee function for "
-          "each caller function. The default is false, meaning it will "
-          "deduplicate only if the input and output shardings are the same."),
+          "If true, regardless of the input and output shardings of functions, "
+          "it keeps one callee function for each caller function. The default "
+          "is false, meaning it will deduplicate only if the input and output "
+          "shardings are the same."),
       llvm::cl::init(false)};
 };
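As a toy sketch of what the two modes deduplicate on (illustrative names only; these are not the pass's data structures), the default mode keys callees on their sharding signature while `dedup-functions-fully` keys them on the name alone:

// Self-contained toy model of the option above; the site list mirrors the
// new test below, where @baz is called with two different out-shardings.
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Each call site: (callee name, in/out-sharding signature).
  std::vector<std::pair<std::string, std::string>> sites = {
      {"baz", "in=[{},{y}] out=[{x},{}]"},
      {"baz", "in=[{},{y}] out=[{x},{y}]"},
      {"baz", "in=[{},{y}] out=[{x},{y}]"}};

  std::set<std::pair<std::string, std::string>> defaultMode;  // per sharding
  std::set<std::string> fullyDeduped;                         // per name
  for (const auto& site : sites) {
    defaultMode.insert(site);
    fullyDeduped.insert(site.first);
  }
  std::cout << "default: " << defaultMode.size() << " callees\n";        // 2
  std::cout << "fully deduped: " << fullyDeduped.size() << " callee\n";  // 1
  return 0;
}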

@@ -8,7 +8,7 @@ func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x
   // CHECK-NEXT: %1 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
   // CHECK-NEXT: return %1 : tensor<8x2xi32>
   %0 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) {
-    %2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
+    %2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>} : tensor<8x2xi32>
     sdy.return %2 : tensor<8x2xi32>
   } : (tensor<8x2xi32>) -> tensor<8x2xi32>
   %1 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) {
@@ -21,6 +21,39 @@ func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x
 // CHECK-LABEL: func private @baz(
 // CHECK-SAME: %arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>})
 // CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>})
+// CHECK-NEXT: stablehlo.multiply %arg0, %arg0
+// CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>}
+
+// -----
+
+sdy.mesh @mesh = <["x"=2, "y"=2]>
+
+// CHECK-LABEL: func @multiple_same_named_computations_different_shardings_different_number_of_call_sites(
+func.func @multiple_same_named_computations_different_shardings_different_number_of_call_sites(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
+  // CHECK-NEXT: %0 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  // CHECK-NEXT: %1 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  // CHECK-NEXT: %2 = call @baz(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  // CHECK-NEXT: return %2 : tensor<8x2xi32>
+  %0 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) {
+    %2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>} : tensor<8x2xi32>
+    sdy.return %2 : tensor<8x2xi32>
+  } : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  %1 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) {
+    %3 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
+    sdy.return %3 : tensor<8x2xi32>
+  } : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  %2 = sdy.named_computation<"baz">(%0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) {
+    %3 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
+    sdy.return %3 : tensor<8x2xi32>
+  } : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  return %2 : tensor<8x2xi32>
+}
+// CHECK-LABEL: func private @baz(
+// CHECK-SAME: %arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>})
+// CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>})
+// CHECK-NEXT: stablehlo.multiply %arg0, %arg0
+// CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>}
 
 // -----
@@ -29,13 +62,17 @@ sdy.mesh @mesh = <["x"=2, "y"=2]>
 // CHECK-LABEL: func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(
 // CHECK-SAME: %arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}
 // CHECK-SAME: -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) {
-// CHECK: %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
-// CHECK-NEXT:   %2 = func.call @foo(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK-NEXT:   %3 = func.call @foo(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK-NEXT:   sdy.return %3 : tensor<4xf32>
+// CHECK-NEXT: %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
+// CHECK-NEXT:   %3 = func.call @foo(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK-NEXT:   %4 = func.call @foo(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK-NEXT:   sdy.return %4 : tensor<4xf32>
 // CHECK-NEXT: } : (tensor<8xf32>) -> tensor<8xf32>
 // CHECK-NEXT: %1 = call @foo_0(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<8xf32>) -> tensor<8xf32>
-// CHECK-NEXT: return %1 : tensor<8xf32>
+// CHECK-NEXT: %2 = sdy.manual_computation(%1) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
+// CHECK-NEXT:   %3 = func.call @foo(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK-NEXT:   sdy.return %3 : tensor<4xf32>
+// CHECK-NEXT: } : (tensor<8xf32>) -> tensor<8xf32>
+// CHECK-NEXT: return %2 : tensor<8xf32>
 func.func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) {
   %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
     %1 = sdy.named_computation<"foo">(%arg1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg2: tensor<4xf32>) {
@@ -52,7 +89,14 @@ func.func @named_computations_same_funcs_two_same_manual_axes_different_sharding
     %6 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<8xf32>
     sdy.return %6 : tensor<8xf32>
   } : (tensor<8xf32>) -> tensor<8xf32>
-  return %5 : tensor<8xf32>
+  %7 = sdy.manual_computation(%5) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
+    %1 = sdy.named_computation<"foo">(%arg1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg2: tensor<4xf32>) {
+      %2 = stablehlo.abs %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<4xf32>
+      sdy.return %2 : tensor<4xf32>
+    } {xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
+    sdy.return %1 : tensor<4xf32>
+  } : (tensor<8xf32>) -> tensor<8xf32>
+  return %7 : tensor<8xf32>
 }
 
 // CHECK-LABEL: func private @foo(