Refactor spmd partitioner.

Flatten else branches that follow a return into early returns, fold std::find
calls into absl::c_find with C++17 if-initializers, remove includes and a Bazel
dependency that are no longer needed, add keep-sorted markers around the
visitor's Handle* overrides, and fix a comment typo (outptuts -> outputs).

PiperOrigin-RevId: 822689391
Zixuan Jiang 2025-10-22 12:13:58 -07:00 committed by TensorFlower Gardener
parent 1b08f96abf
commit 4d53eda2fe
5 changed files with 33 additions and 53 deletions

@@ -89,7 +89,6 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/utility",
"@local_tsl//tsl/platform:numbers",
],
)

@@ -2467,10 +2467,9 @@ SpmdPartitioningVisitor::MakePartitioningState() {
state.collective_ops_creator = *visiting_collective_ops_creator_;
state.partition_id = *visiting_partition_id_;
return CreatePerGroupPartitioningState(state, *device_groups_, &b_);
} else {
}
state.collective_ops_creator = collective_ops_creator_;
state.partition_id = partition_id_;
}
return state;
}
@@ -3323,7 +3322,8 @@ absl::Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
output_shard_size * split_factor);
return operand.state().b->AddInstruction(HloInstruction::CreateReshape(
output_shard_shape, reshard_operand->sharded_input));
} else if (output_dim_size % input_dim_size == 0) {
}
if (output_dim_size % input_dim_size == 0) {
// Merge dims.
int64_t merge_factor = output_dim_size / input_dim_size;
// First reshape locally. (The sharded dimension could include padded
@@ -5054,7 +5054,7 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions,
// If the src/dst pairs are empty, then the collective permute
// just initializes the output to zero.
return CreateZero(operand->shape(), b);
} else {
}
// A collective-permute is a copy if all pairs are "identity" and
// all partitions are listed.
bool is_copy =
@@ -5065,11 +5065,9 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions,
});
if (is_copy) {
return operand;
} else {
}
return b->AddInstruction(HloInstruction::CreateCollectivePermute(
operand->shape(), operand, src_dst_pairs, channel_id));
}
}
},
[create_all_to_all_list_of_lists](
SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
@@ -5309,13 +5307,13 @@ HloInstruction* SpmdPartitioner::AllReduceAlongShardingDimsInternal(
.create_cross_partition_all_reduce_with_iota_device_list(
b, operand, reduction, partition_group_list.value(),
(*next_channel_id)++);
} else {
}
auto partition_subgroups =
GetPartitionGroupsForReplication(sharding, selected_dims);
return collectives_creator.create_cross_partition_all_reduce(
b, operand, reduction, partition_subgroups, (*next_channel_id)++);
}
}
auto result = operand;
for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
if (sharding.tile_assignment().dim(*it) == 1) {
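
Note: the recurring change in this file is flattening an "} else {" that follows
a return into an early return. Below is a minimal standalone sketch of the
before/after shape, using hypothetical DescribeId helpers rather than code from
this module:

#include <cstdint>
#include <string>

// Before: the else block adds a nesting level even though the taken branch
// already returned.
std::string DescribeIdBefore(int64_t id) {
  if (id < 0) {
    return "invalid";
  } else {
    return "id=" + std::to_string(id);
  }
}

// After: returning early lets the remaining statements sit at function scope,
// which is the shape the hunks above move toward.
std::string DescribeIdAfter(int64_t id) {
  if (id < 0) {
    return "invalid";
  }
  return "id=" + std::to_string(id);
}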

@@ -158,8 +158,6 @@ class SpmdBuilder : public HloComputation::Builder {
instructions_[hlo];
}
HloInstruction* visiting_hlo() const { return visiting_hlo_; }
// Wrapper of queries to broadcast_dims_.
std::optional<const absl::flat_hash_set<int64_t>*> BroadcastDimsForCreatedHlo(
const HloInstruction* hlo) {
@@ -370,7 +368,7 @@ class SpmdPartitioner : public HloModulePass {
}
// Update module's parameter and output sharding information, based on the
// sharding information of the module's parameters and outptuts.
// sharding information of the module's parameters and outputs.
static void RecordInputsOutputsSharding(HloModule* module);
int64_t num_partitions() const { return num_partitions_; }
@@ -443,7 +441,6 @@ class SpmdPartitioner : public HloModulePass {
SpmdPartitionerOptions options_;
SPMDCollectiveOpsCreator collective_ops_creator_;
std::vector<std::vector<int64_t>> device_groups_;
absl::flat_hash_set<absl::string_view> execution_threads_;
};
@@ -722,6 +719,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
absl::Status DefaultAction(HloInstruction* hlo) override;
// go/keep-sorted start
absl::Status HandleAllReduce(HloInstruction* hlo) override;
absl::Status HandleBitcastConvert(HloInstruction* hlo) override;
absl::Status HandleBroadcast(HloInstruction* hlo) override;
@@ -760,6 +758,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
absl::Status HandleTriangularSolve(HloInstruction* hlo) override;
absl::Status HandleTuple(HloInstruction* hlo) override;
absl::Status HandleWhile(HloInstruction* hlo) override;
// go/keep-sorted end
// Implementation of dot partitioning given DotGeneralDimsMapping.
template <typename CreateShardedFunctor>

@@ -2534,10 +2534,9 @@ HloSharding CreateMatchingShardingOnDims(
if (to_be_partially_replicated) {
return AlignShardingOnDims(HloSharding::PartialTile(tgt_tile_assignment),
target_dims, source_sharding, source_dims);
} else {
}
return AlignShardingOnDims(HloSharding::Tile(tgt_tile_assignment),
target_dims, source_sharding, source_dims);
}
}
std::optional<GatherScatterParallelDimSharding>
@@ -2909,8 +2908,8 @@ std::vector<std::vector<int64_t>> GetPartitionGroupsAcrossTargetDims(
[&](absl::Span<const int64_t> indices, int64_t device) {
int64_t group_id = 0;
for (int64_t dim = 0; dim < indices.size(); ++dim) {
auto it = absl::c_find(target_dims, dim);
if (it != target_dims.end()) {
if (auto it = absl::c_find(target_dims, dim);
it != target_dims.end()) {
int64_t group_size =
group_sizes[std::distance(target_dims.begin(), it)];
group_id *= sharding.tile_assignment().dim(dim) / group_size;
@@ -2963,8 +2962,7 @@ std::optional<IotaReplicaGroupList> GetIotaPartitionGroupsAcrossTargetDims(
std::vector<int64_t> target_dim_locations;
for (int64_t dim = 0; dim < sharding.tile_assignment().num_dimensions();
++dim) {
auto it = std::find(target_dims.begin(), target_dims.end(), dim);
if (it != target_dims.end()) {
if (auto it = absl::c_find(target_dims, dim); it != target_dims.end()) {
int64_t current_val = sharding.tile_assignment().dim(dim);
int64_t group_size = group_sizes[std::distance(target_dims.begin(), it)];
reshape_dimensions.push_back(current_val / group_size);
@@ -2978,8 +2976,8 @@ std::optional<IotaReplicaGroupList> GetIotaPartitionGroupsAcrossTargetDims(
std::vector<int> transpose_dims(reshape_dimensions.size());
std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
for (int64_t loc : target_dim_locations) {
auto it = std::find(transpose_dims.begin(), transpose_dims.end(), loc);
if (it != transpose_dims.end()) {
if (auto it = absl::c_find(transpose_dims, loc);
it != transpose_dims.end()) {
transpose_dims.erase(it);
transpose_dims.push_back(loc);
}
@@ -3047,8 +3045,7 @@ std::optional<IotaReplicaGroupList> GetIotaPartitionGroupsForReplication(
replication_dims.end());
std::sort(replication_dims_sorted.begin(), replication_dims_sorted.end());
for (int64_t i : replication_dims_sorted) {
auto it = std::find(transpose_dims.begin(), transpose_dims.end(), i);
if (it != transpose_dims.end()) {
if (auto it = absl::c_find(transpose_dims, i); it != transpose_dims.end()) {
transpose_dims.erase(it);
transpose_dims.push_back(i);
}
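
Note: these hunks repeatedly collapse a std::find call and the following check
into absl::c_find inside a C++17 if-initializer, so the iterator is scoped to
the branch that uses it. Below is a minimal sketch of the pattern on a plain
vector (a hypothetical MoveToBack helper, not code from this file):

#include <cstdint>
#include <vector>

#include "absl/algorithm/container.h"

// If value is present in dims, move it to the back. The iterator it only
// exists inside the if statement, mirroring the refactor above.
void MoveToBack(std::vector<int64_t>& dims, int64_t value) {
  if (auto it = absl::c_find(dims, value); it != dims.end()) {
    dims.erase(it);
    dims.push_back(value);
  }
}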

@@ -17,29 +17,21 @@ limitations under the License.
#define XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_replace.h"
#include "absl/types/span.h"
#include "absl/utility/utility.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
@@ -47,7 +39,6 @@ limitations under the License.
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/ir/replica_group.h"
#include "xla/hlo/transforms/simplifiers/hlo_dce.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/hlo/utils/hlo_sharding_util.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
@@ -55,7 +46,6 @@ limitations under the License.
#include "xla/service/spmd/spmd_partitioner.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
@@ -822,9 +812,6 @@ template <typename Arg, IsHloModulePointer<Arg> = 0>
std::decay_t<Arg> FakeHloModule(Arg&& module, HloModule* fake_module) {
return fake_module;
}
template <class T>
using decay_rvalue_reference_t =
std::conditional_t<std::is_rvalue_reference<T>::value, std::decay_t<T>, T>;
// Modifies SpmdPartitioningVisitor* type objects.
template <typename Arg, IsSpmdPartitioningVisitorPointer<Arg> = 0>
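
Note: for reference, the deleted decay_rvalue_reference_t alias decayed
rvalue-reference types to their value type while leaving other types, including
lvalue references, untouched. A small self-contained illustration follows; the
static_asserts are illustrative only and not part of the change:

#include <type_traits>

// Same definition as the alias removed above.
template <class T>
using decay_rvalue_reference_t =
    std::conditional_t<std::is_rvalue_reference<T>::value, std::decay_t<T>, T>;

static_assert(std::is_same_v<decay_rvalue_reference_t<int&&>, int>, "");
static_assert(std::is_same_v<decay_rvalue_reference_t<int&>, int&>, "");
static_assert(std::is_same_v<decay_rvalue_reference_t<const int>, const int>, "");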