mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Refactor spmd partitioner.
PiperOrigin-RevId: 822689391
This commit is contained in:
parent
1b08f96abf
commit
4d53eda2fe
1
third_party/xla/xla/service/spmd/BUILD
vendored
1
third_party/xla/xla/service/spmd/BUILD
vendored
|
|
@ -89,7 +89,6 @@ cc_library(
|
|||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@com_google_absl//absl/utility",
|
||||
"@local_tsl//tsl/platform:numbers",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2467,10 +2467,9 @@ SpmdPartitioningVisitor::MakePartitioningState() {
|
|||
state.collective_ops_creator = *visiting_collective_ops_creator_;
|
||||
state.partition_id = *visiting_partition_id_;
|
||||
return CreatePerGroupPartitioningState(state, *device_groups_, &b_);
|
||||
} else {
|
||||
state.collective_ops_creator = collective_ops_creator_;
|
||||
state.partition_id = partition_id_;
|
||||
}
|
||||
state.collective_ops_creator = collective_ops_creator_;
|
||||
state.partition_id = partition_id_;
|
||||
return state;
|
||||
}
|
||||
|
||||
|
|
@ -3323,7 +3322,8 @@ absl::Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
|
|||
output_shard_size * split_factor);
|
||||
return operand.state().b->AddInstruction(HloInstruction::CreateReshape(
|
||||
output_shard_shape, reshard_operand->sharded_input));
|
||||
} else if (output_dim_size % input_dim_size == 0) {
|
||||
}
|
||||
if (output_dim_size % input_dim_size == 0) {
|
||||
// Merge dims.
|
||||
int64_t merge_factor = output_dim_size / input_dim_size;
|
||||
// First reshape locally. (The sharded dimension could include padded
|
||||
|
|
@ -5054,22 +5054,20 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions,
|
|||
// If the src/dst pairs are empty, then the collective permute
|
||||
// just initializes the output to zero.
|
||||
return CreateZero(operand->shape(), b);
|
||||
} else {
|
||||
// A collective-permute is a copy if all pairs are "identity" and
|
||||
// all partitions are listed.
|
||||
bool is_copy =
|
||||
src_dst_pairs.size() == num_partitions &&
|
||||
absl::c_all_of(src_dst_pairs,
|
||||
[](const std::pair<int64_t, int64_t>& pair) {
|
||||
return pair.first == pair.second;
|
||||
});
|
||||
if (is_copy) {
|
||||
return operand;
|
||||
} else {
|
||||
return b->AddInstruction(HloInstruction::CreateCollectivePermute(
|
||||
operand->shape(), operand, src_dst_pairs, channel_id));
|
||||
}
|
||||
}
|
||||
// A collective-permute is a copy if all pairs are "identity" and
|
||||
// all partitions are listed.
|
||||
bool is_copy =
|
||||
src_dst_pairs.size() == num_partitions &&
|
||||
absl::c_all_of(src_dst_pairs,
|
||||
[](const std::pair<int64_t, int64_t>& pair) {
|
||||
return pair.first == pair.second;
|
||||
});
|
||||
if (is_copy) {
|
||||
return operand;
|
||||
}
|
||||
return b->AddInstruction(HloInstruction::CreateCollectivePermute(
|
||||
operand->shape(), operand, src_dst_pairs, channel_id));
|
||||
},
|
||||
[create_all_to_all_list_of_lists](
|
||||
SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
|
||||
|
|
@ -5309,13 +5307,13 @@ HloInstruction* SpmdPartitioner::AllReduceAlongShardingDimsInternal(
|
|||
.create_cross_partition_all_reduce_with_iota_device_list(
|
||||
b, operand, reduction, partition_group_list.value(),
|
||||
(*next_channel_id)++);
|
||||
} else {
|
||||
auto partition_subgroups =
|
||||
GetPartitionGroupsForReplication(sharding, selected_dims);
|
||||
return collectives_creator.create_cross_partition_all_reduce(
|
||||
b, operand, reduction, partition_subgroups, (*next_channel_id)++);
|
||||
}
|
||||
auto partition_subgroups =
|
||||
GetPartitionGroupsForReplication(sharding, selected_dims);
|
||||
return collectives_creator.create_cross_partition_all_reduce(
|
||||
b, operand, reduction, partition_subgroups, (*next_channel_id)++);
|
||||
}
|
||||
|
||||
auto result = operand;
|
||||
for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
|
||||
if (sharding.tile_assignment().dim(*it) == 1) {
|
||||
|
|
|
|||
|
|
@ -158,8 +158,6 @@ class SpmdBuilder : public HloComputation::Builder {
|
|||
instructions_[hlo];
|
||||
}
|
||||
|
||||
HloInstruction* visiting_hlo() const { return visiting_hlo_; }
|
||||
|
||||
// Wrapper of queries to broadcast_dims_.
|
||||
std::optional<const absl::flat_hash_set<int64_t>*> BroadcastDimsForCreatedHlo(
|
||||
const HloInstruction* hlo) {
|
||||
|
|
@ -370,7 +368,7 @@ class SpmdPartitioner : public HloModulePass {
|
|||
}
|
||||
|
||||
// Update module's parameter and output sharding information, based on the
|
||||
// sharding information of the module's parameters and outptuts.
|
||||
// sharding information of the module's parameters and outputs.
|
||||
static void RecordInputsOutputsSharding(HloModule* module);
|
||||
|
||||
int64_t num_partitions() const { return num_partitions_; }
|
||||
|
|
@ -443,7 +441,6 @@ class SpmdPartitioner : public HloModulePass {
|
|||
|
||||
SpmdPartitionerOptions options_;
|
||||
SPMDCollectiveOpsCreator collective_ops_creator_;
|
||||
std::vector<std::vector<int64_t>> device_groups_;
|
||||
absl::flat_hash_set<absl::string_view> execution_threads_;
|
||||
};
|
||||
|
||||
|
|
@ -722,6 +719,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
|
|||
|
||||
absl::Status DefaultAction(HloInstruction* hlo) override;
|
||||
|
||||
// go/keep-sorted start
|
||||
absl::Status HandleAllReduce(HloInstruction* hlo) override;
|
||||
absl::Status HandleBitcastConvert(HloInstruction* hlo) override;
|
||||
absl::Status HandleBroadcast(HloInstruction* hlo) override;
|
||||
|
|
@ -760,6 +758,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
|
|||
absl::Status HandleTriangularSolve(HloInstruction* hlo) override;
|
||||
absl::Status HandleTuple(HloInstruction* hlo) override;
|
||||
absl::Status HandleWhile(HloInstruction* hlo) override;
|
||||
// go/keep-sorted end
|
||||
|
||||
// Implementation of dot partitioning given DotGeneralDimsMapping.
|
||||
template <typename CreateShardedFunctor>
|
||||
|
|
|
|||
|
|
@ -2534,10 +2534,9 @@ HloSharding CreateMatchingShardingOnDims(
|
|||
if (to_be_partially_replicated) {
|
||||
return AlignShardingOnDims(HloSharding::PartialTile(tgt_tile_assignment),
|
||||
target_dims, source_sharding, source_dims);
|
||||
} else {
|
||||
return AlignShardingOnDims(HloSharding::Tile(tgt_tile_assignment),
|
||||
target_dims, source_sharding, source_dims);
|
||||
}
|
||||
return AlignShardingOnDims(HloSharding::Tile(tgt_tile_assignment),
|
||||
target_dims, source_sharding, source_dims);
|
||||
}
|
||||
|
||||
std::optional<GatherScatterParallelDimSharding>
|
||||
|
|
@ -2909,8 +2908,8 @@ std::vector<std::vector<int64_t>> GetPartitionGroupsAcrossTargetDims(
|
|||
[&](absl::Span<const int64_t> indices, int64_t device) {
|
||||
int64_t group_id = 0;
|
||||
for (int64_t dim = 0; dim < indices.size(); ++dim) {
|
||||
auto it = absl::c_find(target_dims, dim);
|
||||
if (it != target_dims.end()) {
|
||||
if (auto it = absl::c_find(target_dims, dim);
|
||||
it != target_dims.end()) {
|
||||
int64_t group_size =
|
||||
group_sizes[std::distance(target_dims.begin(), it)];
|
||||
group_id *= sharding.tile_assignment().dim(dim) / group_size;
|
||||
|
|
@ -2963,8 +2962,7 @@ std::optional<IotaReplicaGroupList> GetIotaPartitionGroupsAcrossTargetDims(
|
|||
std::vector<int64_t> target_dim_locations;
|
||||
for (int64_t dim = 0; dim < sharding.tile_assignment().num_dimensions();
|
||||
++dim) {
|
||||
auto it = std::find(target_dims.begin(), target_dims.end(), dim);
|
||||
if (it != target_dims.end()) {
|
||||
if (auto it = absl::c_find(target_dims, dim); it != target_dims.end()) {
|
||||
int64_t current_val = sharding.tile_assignment().dim(dim);
|
||||
int64_t group_size = group_sizes[std::distance(target_dims.begin(), it)];
|
||||
reshape_dimensions.push_back(current_val / group_size);
|
||||
|
|
@ -2978,8 +2976,8 @@ std::optional<IotaReplicaGroupList> GetIotaPartitionGroupsAcrossTargetDims(
|
|||
std::vector<int> transpose_dims(reshape_dimensions.size());
|
||||
std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
|
||||
for (int64_t loc : target_dim_locations) {
|
||||
auto it = std::find(transpose_dims.begin(), transpose_dims.end(), loc);
|
||||
if (it != transpose_dims.end()) {
|
||||
if (auto it = absl::c_find(transpose_dims, loc);
|
||||
it != transpose_dims.end()) {
|
||||
transpose_dims.erase(it);
|
||||
transpose_dims.push_back(loc);
|
||||
}
|
||||
|
|
@ -3047,8 +3045,7 @@ std::optional<IotaReplicaGroupList> GetIotaPartitionGroupsForReplication(
|
|||
replication_dims.end());
|
||||
std::sort(replication_dims_sorted.begin(), replication_dims_sorted.end());
|
||||
for (int64_t i : replication_dims_sorted) {
|
||||
auto it = std::find(transpose_dims.begin(), transpose_dims.end(), i);
|
||||
if (it != transpose_dims.end()) {
|
||||
if (auto it = absl::c_find(transpose_dims, i); it != transpose_dims.end()) {
|
||||
transpose_dims.erase(it);
|
||||
transpose_dims.push_back(i);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,29 +17,21 @@ limitations under the License.
|
|||
#define XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <initializer_list>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_replace.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "absl/utility/utility.h"
|
||||
#include "xla/hlo/ir/hlo_casting_utils.h"
|
||||
#include "xla/hlo/ir/hlo_computation.h"
|
||||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
|
|
@ -47,7 +39,6 @@ limitations under the License.
|
|||
#include "xla/hlo/ir/hlo_sharding.h"
|
||||
#include "xla/hlo/ir/replica_group.h"
|
||||
#include "xla/hlo/transforms/simplifiers/hlo_dce.h"
|
||||
#include "xla/hlo/utils/hlo_query.h"
|
||||
#include "xla/hlo/utils/hlo_sharding_util.h"
|
||||
#include "xla/literal.h"
|
||||
#include "xla/literal_util.h"
|
||||
|
|
@ -55,7 +46,6 @@ limitations under the License.
|
|||
#include "xla/service/spmd/spmd_partitioner.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/tsl/platform/errors.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/util.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
|
|
@ -822,9 +812,6 @@ template <typename Arg, IsHloModulePointer<Arg> = 0>
|
|||
std::decay_t<Arg> FakeHloModule(Arg&& module, HloModule* fake_module) {
|
||||
return fake_module;
|
||||
}
|
||||
template <class T>
|
||||
using decay_rvalue_reference_t =
|
||||
std::conditional_t<std::is_rvalue_reference<T>::value, std::decay_t<T>, T>;
|
||||
|
||||
// Modifies SpmdPartitioningVisitor* type objects.
|
||||
template <typename Arg, IsSpmdPartitioningVisitorPointer<Arg> = 0>
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user