Deduplicate functions on the one with the largest number of call sites, instead of picking one arbitrarily.

PiperOrigin-RevId: 822566069
This commit is contained in:
parent 83b84b3c46
commit 39506ad1cd
@@ -15,9 +15,11 @@ limitations under the License.
 
 #include "xla/service/spmd/shardy/round_trip_common/export_named_computations.h"
 
+#include <cstdint>
 #include <memory>
 #include <optional>
+#include <tuple>
 #include <utility>
 
 #include "absl/log/check.h"
 #include "llvm/ADT/DenseMap.h"
@@ -151,6 +153,53 @@ class ExportNamedComputationsPass
     ModuleOp moduleOp = getOperation();
     SymbolTable symbolTable(moduleOp);
+    mlir::Block& moduleBlock = moduleOp.getRegion().front();
+
+    if (dedupFunctionsFully) {
+      llvm::SmallDenseMap<ComputationKey, int64_t> funcCallSiteCounts;
+      llvm::SmallDenseMap<std::pair<StringRef, ManualAxesAttr>,
+                          std::pair<NamedComputationOp, int64_t>>
+          funcToNamedComputations;
+      moduleOp.walk([&](NamedComputationOp namedComputationOp) {
+        ManualAxesAttr manualAxesAttr =
+            namedComputationOp->getAttrOfType<ManualAxesAttr>(kManualAxes);
+        auto key =
+            std::make_tuple(namedComputationOp.getName(),
+                            namedComputationOp.getInShardings().value_or(
+                                TensorShardingPerValueAttr()),
+                            namedComputationOp.getOutShardings().value_or(
+                                TensorShardingPerValueAttr()),
+                            manualAxesAttr);
+        const int64_t callSiteCount = funcCallSiteCounts[key]++;
+        if (auto [it, inserted] = funcToNamedComputations.try_emplace(
+                std::pair(namedComputationOp.getName(), manualAxesAttr),
+                namedComputationOp, callSiteCount);
+            !inserted) {
+          auto& [cachedNamedComputationOp, cachedCallSiteCount] = it->second;
+          if (callSiteCount > cachedCallSiteCount) {
+            cachedNamedComputationOp = namedComputationOp;
+            cachedCallSiteCount = callSiteCount;
+          }
+        }
+      });
+
+      for (auto& [_, namedComputationCountPair] : funcToNamedComputations) {
+        auto& [namedComputationOp, callSiteCount] = namedComputationCountPair;
+        mlir::IRRewriter rewriter(namedComputationOp);
+        rewriter.setInsertionPointToEnd(&moduleBlock);
+        ManualAxesAttr manualAxesAttr =
+            namedComputationOp->getAttrOfType<ManualAxesAttr>(kManualAxes);
+        StringAttr funcSymName =
+            createFuncOp(namedComputationOp, rewriter, symbolTable,
+                         namedComputationOp.getInShardings(),
+                         namedComputationOp.getOutShardings(), manualAxesAttr);
+        funcCache.try_emplace(
+            std::make_tuple(namedComputationOp.getName(),
+                            TensorShardingPerValueAttr(),
+                            TensorShardingPerValueAttr(), manualAxesAttr),
+            funcSymName);
+      }
+    }
+
     // NOTE: The walk needs to be in post order, which is the default order, to
     // account for nested named computations.
     moduleOp.walk([&](NamedComputationOp namedComputationOp) {
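Aside on the selection logic in the hunk above: it is a single-pass "argmax by running count". Every named computation bumps a counter for its full key (name, in/out shardings, manual axes), while a second map keyed only by (name, manual axes) keeps whichever op has seen the highest count so far; since the count is read before the increment, the last occurrence of the most frequent variant wins, and ties keep the earlier entry. A minimal standalone sketch of the same pattern using only standard-library types (all names below are illustrative, not from the XLA source, which stores MLIR ops and attribute keys instead):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Illustrative only: per function name, pick the call site observed with the
// highest running count (ties keep the earlier entry, matching the
// pre-increment read in the pass above).
int main() {
  // Call sites in walk order; the real key also carries shardings and axes.
  std::vector<std::string> callSites = {"baz", "baz", "foo", "baz"};

  std::map<std::string, int64_t> callSiteCounts;
  // name -> (index of winning site, count observed at that site)
  std::map<std::string, std::pair<int, int64_t>> winners;

  for (int i = 0; i < static_cast<int>(callSites.size()); ++i) {
    const std::string& key = callSites[i];
    const int64_t count = callSiteCounts[key]++;  // value *before* increment
    auto [it, inserted] = winners.try_emplace(key, i, count);
    if (!inserted && count > it->second.second) it->second = {i, count};
  }

  for (const auto& [name, winner] : winners)
    std::cout << name << " -> call site #" << winner.first << '\n';
  // Prints: baz -> call site #3, foo -> call site #2
}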
@@ -210,10 +259,10 @@ class ExportNamedComputationsPass
   Option<bool> dedupFunctionsFully{
       *this, "dedup-functions-fully",
       llvm::cl::desc(
-          "Whether to deduplicate functions fully, regardless of the input and "
-          "output shardings of functions, and it keeps one callee function for "
-          "each caller function. The default is false, meaning it will "
-          "deduplicate only if the input and output shardings are the same."),
+          "If true, regardless of the input and output shardings of functions, "
+          "it keeps one callee function for each caller function. The default "
+          "is false, meaning it will deduplicate only if the input and output "
+          "shardings are the same."),
       llvm::cl::init(false)};
 };
 
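For readers unfamiliar with MLIR pass infrastructure: an `Option<bool>` member like `dedupFunctionsFully` is set through the textual pass-pipeline syntax. Only the `dedup-functions-fully` option name comes from the code above; the tool and pass names in this sketch are placeholders:

some-opt-tool input.mlir \
    -pass-pipeline='builtin.module(some-export-named-computations-pass{dedup-functions-fully=true})'

The remaining hunks below are from the corresponding MLIR round-trip FileCheck test file.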
@@ -8,7 +8,7 @@ func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x
   // CHECK-NEXT: %1 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
   // CHECK-NEXT: return %1 : tensor<8x2xi32>
   %0 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) {
-    %2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
+    %2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>} : tensor<8x2xi32>
     sdy.return %2 : tensor<8x2xi32>
   } : (tensor<8x2xi32>) -> tensor<8x2xi32>
   %1 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) {
@@ -21,6 +21,39 @@ func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x
 // CHECK-LABEL: func private @baz(
 // CHECK-SAME: %arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>})
 // CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>})
 // CHECK-NEXT: stablehlo.multiply %arg0, %arg0
 // CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>}
 
+// -----
+
+sdy.mesh @mesh = <["x"=2, "y"=2]>
+
+// CHECK-LABEL: func @multiple_same_named_computations_different_shardings_different_number_of_call_sites(
+func.func @multiple_same_named_computations_different_shardings_different_number_of_call_sites(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
+  // CHECK-NEXT: %0 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  // CHECK-NEXT: %1 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  // CHECK-NEXT: %2 = call @baz(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  // CHECK-NEXT: return %2 : tensor<8x2xi32>
+  %0 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) {
+    %2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>} : tensor<8x2xi32>
+    sdy.return %2 : tensor<8x2xi32>
+  } : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  %1 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) {
+    %3 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
+    sdy.return %3 : tensor<8x2xi32>
+  } : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  %2 = sdy.named_computation<"baz">(%0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) {
+    %3 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
+    sdy.return %3 : tensor<8x2xi32>
+  } : (tensor<8x2xi32>) -> tensor<8x2xi32>
+  return %2 : tensor<8x2xi32>
+}
+
+// CHECK-LABEL: func private @baz(
+// CHECK-SAME: %arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>})
+// CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>})
+// CHECK-NEXT: stablehlo.multiply %arg0, %arg0
+// CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>}
+
+// -----
+
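The new test above exercises the commit's core behavior: `baz` has three call sites, one variant with out_shardings `[<@mesh, [{"x"}, {}]>]` and two with `[<@mesh, [{"x"}, {"y"}]>]`. The CHECK lines require all three calls to target a single `@baz` whose result sharding (`[{"x"}, {"y"}]`) and body sharding (`[{"x", ?}, {"y", ?}]`) come from the two-call-site variant, so the variant with the most call sites wins rather than an arbitrary pick.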
@@ -29,13 +62,17 @@ sdy.mesh @mesh = <["x"=2, "y"=2]>
 // CHECK-LABEL: func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(
 // CHECK-SAME: %arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}
 // CHECK-SAME: -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) {
-// CHECK: %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
-// CHECK-NEXT: %2 = func.call @foo(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK-NEXT: %3 = func.call @foo(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK-NEXT: sdy.return %3 : tensor<4xf32>
+// CHECK-NEXT: %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
+// CHECK-NEXT: %3 = func.call @foo(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK-NEXT: %4 = func.call @foo(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK-NEXT: sdy.return %4 : tensor<4xf32>
 // CHECK-NEXT: } : (tensor<8xf32>) -> tensor<8xf32>
 // CHECK-NEXT: %1 = call @foo_0(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<8xf32>) -> tensor<8xf32>
-// CHECK-NEXT: return %1 : tensor<8xf32>
+// CHECK-NEXT: %2 = sdy.manual_computation(%1) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
+// CHECK-NEXT: %3 = func.call @foo(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK-NEXT: sdy.return %3 : tensor<4xf32>
+// CHECK-NEXT: } : (tensor<8xf32>) -> tensor<8xf32>
+// CHECK-NEXT: return %2 : tensor<8xf32>
 func.func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) {
   %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
     %1 = sdy.named_computation<"foo">(%arg1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg2: tensor<4xf32>) {
@@ -52,7 +89,14 @@ func.func @named_computations_same_funcs_two_same_manual_axes_different_sharding
     %6 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<8xf32>
     sdy.return %6 : tensor<8xf32>
   } : (tensor<8xf32>) -> tensor<8xf32>
-  return %5 : tensor<8xf32>
+  %7 = sdy.manual_computation(%5) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
+    %1 = sdy.named_computation<"foo">(%arg1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg2: tensor<4xf32>) {
+      %2 = stablehlo.abs %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<4xf32>
+      sdy.return %2 : tensor<4xf32>
+    } {xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
+    sdy.return %1 : tensor<4xf32>
+  } : (tensor<8xf32>) -> tensor<8xf32>
+  return %7 : tensor<8xf32>
 }
 
 // CHECK-LABEL: func private @foo(
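The last two hunks extend the manual-axes test: the input gains a second `sdy.manual_computation` wrapping the same `"foo"` named computation, and the updated CHECK lines require both manual computations to call the same `@foo`, while the call site without manual axes still resolves to a separate `@foo_0`. This matches the dedup key above, `std::pair<StringRef, ManualAxesAttr>`: call sites are merged only when both the function name and the `xla.sdy.manual_axes` attribute agree.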