mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Deduplicate functions onto the one with the largest number of call sites,
instead of picking one arbitrarily. PiperOrigin-RevId: 822566069
This commit is contained in:
parent
83b84b3c46
commit
39506ad1cd
|
|
@ -15,9 +15,11 @@ limitations under the License.
|
||||||
|
|
||||||
#include "xla/service/spmd/shardy/round_trip_common/export_named_computations.h"
|
#include "xla/service/spmd/shardy/round_trip_common/export_named_computations.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "absl/log/check.h"
|
#include "absl/log/check.h"
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
|
@ -151,6 +153,53 @@ class ExportNamedComputationsPass
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
SymbolTable symbolTable(moduleOp);
|
SymbolTable symbolTable(moduleOp);
|
||||||
mlir::Block& moduleBlock = moduleOp.getRegion().front();
|
mlir::Block& moduleBlock = moduleOp.getRegion().front();
|
||||||
|
|
||||||
|
if (dedupFunctionsFully) {
|
||||||
|
llvm::SmallDenseMap<ComputationKey, int64_t> funcCallSiteCounts;
|
||||||
|
llvm::SmallDenseMap<std::pair<StringRef, ManualAxesAttr>,
|
||||||
|
std::pair<NamedComputationOp, int64_t>>
|
||||||
|
funcToNamedComputations;
|
||||||
|
moduleOp.walk([&](NamedComputationOp namedComputationOp) {
|
||||||
|
ManualAxesAttr manualAxesAttr =
|
||||||
|
namedComputationOp->getAttrOfType<ManualAxesAttr>(kManualAxes);
|
||||||
|
auto key =
|
||||||
|
std::make_tuple(namedComputationOp.getName(),
|
||||||
|
namedComputationOp.getInShardings().value_or(
|
||||||
|
TensorShardingPerValueAttr()),
|
||||||
|
namedComputationOp.getOutShardings().value_or(
|
||||||
|
TensorShardingPerValueAttr()),
|
||||||
|
manualAxesAttr);
|
||||||
|
const int64_t callSiteCount = funcCallSiteCounts[key]++;
|
||||||
|
if (auto [it, inserted] = funcToNamedComputations.try_emplace(
|
||||||
|
std::pair(namedComputationOp.getName(), manualAxesAttr),
|
||||||
|
namedComputationOp, callSiteCount);
|
||||||
|
!inserted) {
|
||||||
|
auto& [cachedNamedComputationOp, cachedCallSiteCount] = it->second;
|
||||||
|
if (callSiteCount > cachedCallSiteCount) {
|
||||||
|
cachedNamedComputationOp = namedComputationOp;
|
||||||
|
cachedCallSiteCount = callSiteCount;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
for (auto& [_, namedComputationCountPair] : funcToNamedComputations) {
|
||||||
|
auto& [namedComputationOp, callSiteCount] = namedComputationCountPair;
|
||||||
|
mlir::IRRewriter rewriter(namedComputationOp);
|
||||||
|
rewriter.setInsertionPointToEnd(&moduleBlock);
|
||||||
|
ManualAxesAttr manualAxesAttr =
|
||||||
|
namedComputationOp->getAttrOfType<ManualAxesAttr>(kManualAxes);
|
||||||
|
StringAttr funcSymName =
|
||||||
|
createFuncOp(namedComputationOp, rewriter, symbolTable,
|
||||||
|
namedComputationOp.getInShardings(),
|
||||||
|
namedComputationOp.getOutShardings(), manualAxesAttr);
|
||||||
|
funcCache.try_emplace(
|
||||||
|
std::make_tuple(namedComputationOp.getName(),
|
||||||
|
TensorShardingPerValueAttr(),
|
||||||
|
TensorShardingPerValueAttr(), manualAxesAttr),
|
||||||
|
funcSymName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NOTE: The walk needs to be in post order, which is the default order, to
|
// NOTE: The walk needs to be in post order, which is the default order, to
|
||||||
// account for nested named computations.
|
// account for nested named computations.
|
||||||
moduleOp.walk([&](NamedComputationOp namedComputationOp) {
|
moduleOp.walk([&](NamedComputationOp namedComputationOp) {
|
||||||
|
|
@ -210,10 +259,10 @@ class ExportNamedComputationsPass
|
||||||
Option<bool> dedupFunctionsFully{
|
Option<bool> dedupFunctionsFully{
|
||||||
*this, "dedup-functions-fully",
|
*this, "dedup-functions-fully",
|
||||||
llvm::cl::desc(
|
llvm::cl::desc(
|
||||||
"Whether to deduplicate functions fully, regardless of the input and "
|
"If true, regardless of the input and output shardings of functions, "
|
||||||
"output shardings of functions, and it keeps one callee function for "
|
"it keeps one callee function for each caller function. The default "
|
||||||
"each caller function. The default is false, meaning it will "
|
"is false, meaning it will deduplicate only if the input and output "
|
||||||
"deduplicate only if the input and output shardings are the same."),
|
"shardings are the same."),
|
||||||
llvm::cl::init(false)};
|
llvm::cl::init(false)};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x
|
||||||
// CHECK-NEXT: %1 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
|
// CHECK-NEXT: %1 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
|
||||||
// CHECK-NEXT: return %1 : tensor<8x2xi32>
|
// CHECK-NEXT: return %1 : tensor<8x2xi32>
|
||||||
%0 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) {
|
%0 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) {
|
||||||
%2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
|
%2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>} : tensor<8x2xi32>
|
||||||
sdy.return %2 : tensor<8x2xi32>
|
sdy.return %2 : tensor<8x2xi32>
|
||||||
} : (tensor<8x2xi32>) -> tensor<8x2xi32>
|
} : (tensor<8x2xi32>) -> tensor<8x2xi32>
|
||||||
%1 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) {
|
%1 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) {
|
||||||
|
|
@ -21,6 +21,39 @@ func.func @multiple_same_named_computations_different_shardings(%arg0: tensor<8x
|
||||||
// CHECK-LABEL: func private @baz(
|
// CHECK-LABEL: func private @baz(
|
||||||
// CHECK-SAME: %arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>})
|
// CHECK-SAME: %arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>})
|
||||||
// CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>})
|
// CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>})
|
||||||
|
// CHECK-NEXT: stablehlo.multiply %arg0, %arg0
|
||||||
|
// CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
sdy.mesh @mesh = <["x"=2, "y"=2]>
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @multiple_same_named_computations_different_shardings_different_number_of_call_sites(
|
||||||
|
func.func @multiple_same_named_computations_different_shardings_different_number_of_call_sites(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
|
||||||
|
// CHECK-NEXT: %0 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
|
||||||
|
// CHECK-NEXT: %1 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
|
||||||
|
// CHECK-NEXT: %2 = call @baz(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32>
|
||||||
|
// CHECK-NEXT: return %2 : tensor<8x2xi32>
|
||||||
|
%0 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) {
|
||||||
|
%2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {?}]>]>} : tensor<8x2xi32>
|
||||||
|
sdy.return %2 : tensor<8x2xi32>
|
||||||
|
} : (tensor<8x2xi32>) -> tensor<8x2xi32>
|
||||||
|
%1 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) {
|
||||||
|
%3 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
|
||||||
|
sdy.return %3 : tensor<8x2xi32>
|
||||||
|
} : (tensor<8x2xi32>) -> tensor<8x2xi32>
|
||||||
|
%2 = sdy.named_computation<"baz">(%0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) {
|
||||||
|
%3 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32>
|
||||||
|
sdy.return %3 : tensor<8x2xi32>
|
||||||
|
} : (tensor<8x2xi32>) -> tensor<8x2xi32>
|
||||||
|
return %2 : tensor<8x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func private @baz(
|
||||||
|
// CHECK-SAME: %arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>})
|
||||||
|
// CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>})
|
||||||
|
// CHECK-NEXT: stablehlo.multiply %arg0, %arg0
|
||||||
|
// CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
|
@ -29,13 +62,17 @@ sdy.mesh @mesh = <["x"=2, "y"=2]>
|
||||||
// CHECK-LABEL: func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(
|
// CHECK-LABEL: func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(
|
||||||
// CHECK-SAME: %arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}
|
// CHECK-SAME: %arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}
|
||||||
// CHECK-SAME: -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) {
|
// CHECK-SAME: -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) {
|
||||||
// CHECK: %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
|
// CHECK-NEXT: %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
|
||||||
// CHECK-NEXT: %2 = func.call @foo(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
|
// CHECK-NEXT: %3 = func.call @foo(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
|
||||||
// CHECK-NEXT: %3 = func.call @foo(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
|
// CHECK-NEXT: %4 = func.call @foo(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
|
||||||
// CHECK-NEXT: sdy.return %3 : tensor<4xf32>
|
// CHECK-NEXT: sdy.return %4 : tensor<4xf32>
|
||||||
// CHECK-NEXT: } : (tensor<8xf32>) -> tensor<8xf32>
|
// CHECK-NEXT: } : (tensor<8xf32>) -> tensor<8xf32>
|
||||||
// CHECK-NEXT: %1 = call @foo_0(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<8xf32>) -> tensor<8xf32>
|
// CHECK-NEXT: %1 = call @foo_0(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : (tensor<8xf32>) -> tensor<8xf32>
|
||||||
// CHECK-NEXT: return %1 : tensor<8xf32>
|
// CHECK-NEXT: %2 = sdy.manual_computation(%1) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
|
||||||
|
// CHECK-NEXT: %3 = func.call @foo(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>, xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
|
||||||
|
// CHECK-NEXT: sdy.return %3 : tensor<4xf32>
|
||||||
|
// CHECK-NEXT: } : (tensor<8xf32>) -> tensor<8xf32>
|
||||||
|
// CHECK-NEXT: return %2 : tensor<8xf32>
|
||||||
func.func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) {
|
func.func @named_computations_same_funcs_two_same_manual_axes_different_shardings_one_without_manual_axes(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) -> (tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}]>}) {
|
||||||
%0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
|
%0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
|
||||||
%1 = sdy.named_computation<"foo">(%arg1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg2: tensor<4xf32>) {
|
%1 = sdy.named_computation<"foo">(%arg1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg2: tensor<4xf32>) {
|
||||||
|
|
@ -52,7 +89,14 @@ func.func @named_computations_same_funcs_two_same_manual_axes_different_sharding
|
||||||
%6 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<8xf32>
|
%6 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<8xf32>
|
||||||
sdy.return %6 : tensor<8xf32>
|
sdy.return %6 : tensor<8xf32>
|
||||||
} : (tensor<8xf32>) -> tensor<8xf32>
|
} : (tensor<8xf32>) -> tensor<8xf32>
|
||||||
return %5 : tensor<8xf32>
|
%7 = sdy.manual_computation(%5) in_shardings=[<@mesh, [{"x", "y"}]>] out_shardings=[<@mesh, [{"x", "y"}]>] manual_axes={"x"} (%arg1: tensor<4xf32>) {
|
||||||
|
%1 = sdy.named_computation<"foo">(%arg1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg2: tensor<4xf32>) {
|
||||||
|
%2 = stablehlo.abs %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<4xf32>
|
||||||
|
sdy.return %2 : tensor<4xf32>
|
||||||
|
} {xla.sdy.manual_axes = #sdy<manual_axes{"x"}>} : (tensor<4xf32>) -> tensor<4xf32>
|
||||||
|
sdy.return %1 : tensor<4xf32>
|
||||||
|
} : (tensor<8xf32>) -> tensor<8xf32>
|
||||||
|
return %7 : tensor<8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func private @foo(
|
// CHECK-LABEL: func private @foo(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user