mirror of https://github.com/zebrajr/tensorflow.git (synced 2025-12-06 12:20:11 +01:00)
[XLA:SPMD] Allows backend to specify whether to use multistep allgather/allreduce
PiperOrigin-RevId: 356773102
Change-Id: Ib66f85ccf7fe8a76c31dc7f6322f3f6e28488d97
parent 1ccefa8d5e
commit b3cba5a7aa
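In short: the public AllGatherShards and AllReduceAlongShardingDims entry points now delegate to new *Internal overloads carrying a per_dim_ag / per_dim_ar flag. The default (true) keeps the existing multistep behavior of one collective per selected dimension; a backend that prefers a single collective over the combined partition groups can pass false. Which choice is faster depends on the interconnect and runtime, hence leaving it to the backend.

Below is a minimal, self-contained sketch of what the two strategies mean for replica grouping, assuming a hypothetical 2x2 partition grid sharded on both dimensions. PartitionId, GroupsAlongDim0/GroupsAlongDim1, and CombinedGroups are illustrative stand-ins for what GetPartitionGroupsForReplication produces for {dim} versus selected_dims; they are not XLA APIs.

// Illustrative only: a toy model of the partition groups used by the
// per-dimension (multistep) path versus the combined single-shot path.
#include <cstdint>
#include <iostream>
#include <vector>

using Groups = std::vector<std::vector<int64_t>>;

// Partition ids laid out row-major on a hypothetical 2x2 device grid:
//   {{0, 1},
//    {2, 3}}
constexpr int64_t kDim0 = 2, kDim1 = 2;

int64_t PartitionId(int64_t i, int64_t j) { return i * kDim1 + j; }

// Groups for an all-gather along dim 1 only: one group per row.
Groups GroupsAlongDim1() {
  Groups groups;
  for (int64_t i = 0; i < kDim0; ++i) {
    std::vector<int64_t> g;
    for (int64_t j = 0; j < kDim1; ++j) g.push_back(PartitionId(i, j));
    groups.push_back(g);
  }
  return groups;
}

// Groups for an all-gather along dim 0 only: one group per column.
Groups GroupsAlongDim0() {
  Groups groups;
  for (int64_t j = 0; j < kDim1; ++j) {
    std::vector<int64_t> g;
    for (int64_t i = 0; i < kDim0; ++i) g.push_back(PartitionId(i, j));
    groups.push_back(g);
  }
  return groups;
}

// Combined groups over both dims: every partition in a single group.
Groups CombinedGroups() {
  std::vector<int64_t> g;
  for (int64_t i = 0; i < kDim0; ++i)
    for (int64_t j = 0; j < kDim1; ++j) g.push_back(PartitionId(i, j));
  return {g};
}

void Print(const char* label, const Groups& groups) {
  std::cout << label << ": ";
  for (const auto& g : groups) {
    std::cout << "{ ";
    for (int64_t id : g) std::cout << id << " ";
    std::cout << "} ";
  }
  std::cout << "\n";
}

int main() {
  // per_dim_ag=true: two all-gathers, each over small groups.
  Print("step 1 (dim 1)", GroupsAlongDim1());  // { 0 1 } { 2 3 }
  Print("step 2 (dim 0)", GroupsAlongDim0());  // { 0 2 } { 1 3 }
  // per_dim_ag=false: one all-gather over the combined groups.
  Print("single step   ", CombinedGroups());   // { 0 1 2 3 }
}

Note the per-dimension path walks selected_dims in reverse (rbegin in the diff below), so for selected dims {0, 1} the dim-1 all-gather runs first.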
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
@@ -3500,6 +3500,15 @@ HloInstruction* SpmdPartitioner::AllGatherShards(
     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
     int64* next_channel_id, absl::Span<const int64> selected_dims,
     const SPMDCollectiveOpsCreator& collectives_creator) {
+  return AllGatherShardsInternal(b, operand, sharding, next_channel_id,
+                                 selected_dims, collectives_creator,
+                                 /*per_dim_ag=*/true);
+}
+
+HloInstruction* SpmdPartitioner::AllGatherShardsInternal(
+    SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
+    int64* next_channel_id, absl::Span<const int64> selected_dims,
+    const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag) {
   if (selected_dims.empty()) {
     return operand;
   }
@@ -3513,16 +3522,28 @@ HloInstruction* SpmdPartitioner::AllGatherShards(
   auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
   HloInstruction* result = reshape;
-  for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
-    if (sharding.tile_assignment().dim(*it) == 1) {
-      continue;
+  if (per_dim_ag) {
+    for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
+      if (sharding.tile_assignment().dim(*it) == 1) {
+        continue;
+      }
+      auto partition_subgroups =
+          GetPartitionGroupsForReplication(sharding, {*it});
+      shape[0] *= partition_subgroups[0].size();
+      result = collectives_creator.create_cross_partition_all_gather(
+          b, result,
+          ShapeUtil::MakeShape(operand->shape().element_type(), shape),
+          partition_subgroups, (*next_channel_id)++,
+          /*all_gather_dimension=*/0);
     }
+  } else {
     auto partition_subgroups =
-        GetPartitionGroupsForReplication(sharding, {*it});
+        GetPartitionGroupsForReplication(sharding, selected_dims);
     shape[0] *= partition_subgroups[0].size();
     result = collectives_creator.create_cross_partition_all_gather(
         b, result, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
-        partition_subgroups, (*next_channel_id)++, /*all_gather_dimension=*/0);
+        partition_subgroups, (*next_channel_id)++,
+        /*all_gather_dimension=*/0);
   }
   // If n > 1 dimensions are partitioned, split the leading dimension to n.
   std::vector<int64> tiled_dims;
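A quick sanity check on the shape bookkeeping in the hunk above: both branches grow the gathered leading dimension identically, since multiplying shape[0] by each selected dimension's group size in turn equals multiplying once by the combined group size. A minimal sketch, assuming two selected dimensions with group size 2 each:

// Sketch of the leading-dimension bookkeeping, assuming a 2x2 tile
// assignment sharded on both selected dims (group sizes 2 and 2).
#include <cassert>
#include <cstdint>

int main() {
  const int64_t group_size_dim0 = 2, group_size_dim1 = 2;

  // per_dim_ag=true: shape[0] grows once per all-gather step.
  int64_t per_dim_shape0 = 1;
  per_dim_shape0 *= group_size_dim1;  // after the dim-1 all-gather
  per_dim_shape0 *= group_size_dim0;  // after the dim-0 all-gather

  // per_dim_ag=false: shape[0] grows once, by the combined group size.
  int64_t single_shape0 = 1;
  single_shape0 *= group_size_dim0 * group_size_dim1;  // one all-gather

  // Both paths produce the same gathered shape; only the number of
  // collectives (and their group sizes) differs.
  assert(per_dim_shape0 == single_shape0);
}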
@@ -3579,6 +3600,22 @@ HloInstruction* SpmdPartitioner::AllReduceAlongShardingDims(
     int64* next_channel_id, absl::Span<const int64> selected_dims,
     const SPMDCollectiveOpsCreator& collectives_creator,
     HloComputation* reduction) {
+  return AllReduceAlongShardingDimsInternal(
+      b, operand, sharding, next_channel_id, selected_dims, collectives_creator,
+      reduction, /*per_dim_ar=*/true);
+}
+
+HloInstruction* SpmdPartitioner::AllReduceAlongShardingDimsInternal(
+    SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
+    int64* next_channel_id, absl::Span<const int64> selected_dims,
+    const SPMDCollectiveOpsCreator& collectives_creator,
+    HloComputation* reduction, bool per_dim_ar) {
+  if (!per_dim_ar) {
+    auto partition_subgroups =
+        GetPartitionGroupsForReplication(sharding, selected_dims);
+    return collectives_creator.create_cross_partition_all_reduce(
+        b, operand, reduction, partition_subgroups, (*next_channel_id)++);
+  }
   auto result = operand;
   for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
     if (sharding.tile_assignment().dim(*it) == 1) {
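The all-reduce variant mirrors this: with per_dim_ar=true (what the public AllReduceAlongShardingDims passes) the loop after the early return emits one all-reduce per selected dimension whose tile size exceeds 1, while per_dim_ar=false emits exactly one all-reduce over the combined groups. A small sketch of the resulting collective counts; NumAllReduces is a hypothetical illustration, not an XLA helper:

// Hypothetical helper mirroring the control flow above: count how many
// cross-partition all-reduces each path would emit for a given tiling.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t NumAllReduces(const std::vector<int64_t>& tile_dims,
                      const std::vector<int64_t>& selected_dims,
                      bool per_dim_ar) {
  if (!per_dim_ar) return 1;  // single all-reduce over combined groups
  int64_t count = 0;
  for (int64_t dim : selected_dims) {
    if (tile_dims[dim] > 1) ++count;  // dims of size 1 are skipped
  }
  return count;
}

int main() {
  std::vector<int64_t> tile_dims = {4, 2, 1};  // e.g. a 4x2x1 tile assignment
  std::vector<int64_t> selected = {0, 1, 2};
  std::cout << NumAllReduces(tile_dims, selected, /*per_dim_ar=*/true)   // 2
            << "\n"
            << NumAllReduces(tile_dims, selected, /*per_dim_ar=*/false)  // 1
            << "\n";
}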
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
@@ -226,6 +226,16 @@ class SpmdPartitioner : public HloModulePass {
       int64* next_channel_id, SpmdLogger* logger,
       SpmdPartitionerOptions options);
 
+  HloInstruction* AllGatherShardsInternal(
+      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
+      int64* next_channel_id, absl::Span<const int64> selected_dims,
+      const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag);
+  HloInstruction* AllReduceAlongShardingDimsInternal(
+      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
+      int64* next_channel_id, absl::Span<const int64> selected_dims,
+      const SPMDCollectiveOpsCreator& collectives_creator,
+      HloComputation* reduction, bool per_dim_ar);
+
   // Verify that the sharding of instructions in the module are valid, and also
   // fill in missing sharding information.
   Status PreprocessSharding(HloModule* module);