[XLA:SPMD] Allows backend to specify whether to use multistep allgather/allreduce

PiperOrigin-RevId: 356773102
Change-Id: Ib66f85ccf7fe8a76c31dc7f6322f3f6e28488d97
Authored by Yuanzhong Xu on 2021-02-10 10:49:37 -08:00; committed by TensorFlower Gardener
parent 1ccefa8d5e
commit b3cba5a7aa
2 changed files with 52 additions and 5 deletions

tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc

@@ -3500,6 +3500,15 @@ HloInstruction* SpmdPartitioner::AllGatherShards(
     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
     int64* next_channel_id, absl::Span<const int64> selected_dims,
     const SPMDCollectiveOpsCreator& collectives_creator) {
+  return AllGatherShardsInternal(b, operand, sharding, next_channel_id,
+                                 selected_dims, collectives_creator,
+                                 /*per_dim_ag=*/true);
+}
+
+HloInstruction* SpmdPartitioner::AllGatherShardsInternal(
+    SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
+    int64* next_channel_id, absl::Span<const int64> selected_dims,
+    const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag) {
   if (selected_dims.empty()) {
     return operand;
   }
@@ -3513,16 +3522,28 @@ HloInstruction* SpmdPartitioner::AllGatherShards(
   auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
   HloInstruction* result = reshape;
-  for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
-    if (sharding.tile_assignment().dim(*it) == 1) {
-      continue;
+  if (per_dim_ag) {
+    for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
+      if (sharding.tile_assignment().dim(*it) == 1) {
+        continue;
+      }
+      auto partition_subgroups =
+          GetPartitionGroupsForReplication(sharding, {*it});
+      shape[0] *= partition_subgroups[0].size();
+      result = collectives_creator.create_cross_partition_all_gather(
+          b, result,
+          ShapeUtil::MakeShape(operand->shape().element_type(), shape),
+          partition_subgroups, (*next_channel_id)++,
+          /*all_gather_dimension=*/0);
     }
+  } else {
     auto partition_subgroups =
-        GetPartitionGroupsForReplication(sharding, {*it});
+        GetPartitionGroupsForReplication(sharding, selected_dims);
     shape[0] *= partition_subgroups[0].size();
     result = collectives_creator.create_cross_partition_all_gather(
         b, result, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
-        partition_subgroups, (*next_channel_id)++, /*all_gather_dimension=*/0);
+        partition_subgroups, (*next_channel_id)++,
+        /*all_gather_dimension=*/0);
   }
   // If n > 1 dimensions are partitioned, split the leading dimension to n.
   std::vector<int64> tiled_dims;
@@ -3579,6 +3600,22 @@ HloInstruction* SpmdPartitioner::AllReduceAlongShardingDims(
     int64* next_channel_id, absl::Span<const int64> selected_dims,
     const SPMDCollectiveOpsCreator& collectives_creator,
     HloComputation* reduction) {
+  return AllReduceAlongShardingDimsInternal(
+      b, operand, sharding, next_channel_id, selected_dims, collectives_creator,
+      reduction, /*per_dim_ar=*/true);
+}
+
+HloInstruction* SpmdPartitioner::AllReduceAlongShardingDimsInternal(
+    SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
+    int64* next_channel_id, absl::Span<const int64> selected_dims,
+    const SPMDCollectiveOpsCreator& collectives_creator,
+    HloComputation* reduction, bool per_dim_ar) {
+  if (!per_dim_ar) {
+    auto partition_subgroups =
+        GetPartitionGroupsForReplication(sharding, selected_dims);
+    return collectives_creator.create_cross_partition_all_reduce(
+        b, operand, reduction, partition_subgroups, (*next_channel_id)++);
+  }
   auto result = operand;
   for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
     if (sharding.tile_assignment().dim(*it) == 1) {
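
To make the new per_dim_ag/per_dim_ar flag concrete: the per-dimension path issues one all-gather (or all-reduce) per sharded dimension, each over small partition groups, while the single-step path issues one collective over the combined groups for all selected dimensions. The standalone sketch below is not part of this commit; ReplicationGroups is a simplified stand-in for GetPartitionGroupsForReplication and assumes a row-major iota tile assignment. It prints the groups for a 2x4 tile assignment over 8 partitions.

#include <cstdio>
#include <map>
#include <vector>

// Stand-in for XLA's GetPartitionGroupsForReplication: partitions that share
// coordinates on all non-selected dimensions of the tile assignment end up in
// the same group (their shards are gathered/reduced together).
std::vector<std::vector<int>> ReplicationGroups(
    const std::vector<int>& tile_dims, const std::vector<int>& selected_dims) {
  int num_devices = 1;
  for (int d : tile_dims) num_devices *= d;
  std::vector<bool> selected(tile_dims.size(), false);
  for (int d : selected_dims) selected[d] = true;
  std::map<std::vector<int>, std::vector<int>> groups;
  for (int device = 0; device < num_devices; ++device) {
    // Row-major coordinates of `device` in the tile assignment.
    std::vector<int> coords(tile_dims.size());
    int rest = device;
    for (int d = static_cast<int>(tile_dims.size()) - 1; d >= 0; --d) {
      coords[d] = rest % tile_dims[d];
      rest /= tile_dims[d];
    }
    // Group by the coordinates of the non-selected dimensions only.
    std::vector<int> key;
    for (int d = 0; d < static_cast<int>(tile_dims.size()); ++d) {
      if (!selected[d]) key.push_back(coords[d]);
    }
    groups[key].push_back(device);
  }
  std::vector<std::vector<int>> result;
  for (auto& kv : groups) result.push_back(kv.second);
  return result;
}

void Print(const char* label, const std::vector<std::vector<int>>& groups) {
  std::printf("%s:", label);
  for (const auto& g : groups) {
    std::printf(" {");
    for (int i = 0; i < static_cast<int>(g.size()); ++i) {
      std::printf(i ? ",%d" : "%d", g[i]);
    }
    std::printf("}");
  }
  std::printf("\n");
}

int main() {
  std::vector<int> tile = {2, 4};  // 8 partitions in a 2x4 tile assignment.
  // per_dim_ag=true: one all-gather per sharded dimension.
  Print("dim 0 only", ReplicationGroups(tile, {0}));     // 4 groups of 2
  Print("dim 1 only", ReplicationGroups(tile, {1}));     // 2 groups of 4
  // per_dim_ag=false: a single all-gather over the combined groups.
  Print("dims {0,1}", ReplicationGroups(tile, {0, 1}));  // 1 group of 8
}

Running it prints per-dimension groups {0,4},{1,5},{2,6},{3,7} and {0,1,2,3},{4,5,6,7} versus a single combined group {0,...,7}: several smaller collectives or one larger one, which is the trade-off the flag exposes to backends.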

tensorflow/compiler/xla/service/spmd/spmd_partitioner.h

@@ -226,6 +226,16 @@ class SpmdPartitioner : public HloModulePass {
       int64* next_channel_id, SpmdLogger* logger,
       SpmdPartitionerOptions options);
 
+  HloInstruction* AllGatherShardsInternal(
+      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
+      int64* next_channel_id, absl::Span<const int64> selected_dims,
+      const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag);
+  HloInstruction* AllReduceAlongShardingDimsInternal(
+      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
+      int64* next_channel_id, absl::Span<const int64> selected_dims,
+      const SPMDCollectiveOpsCreator& collectives_creator,
+      HloComputation* reduction, bool per_dim_ar);
+
   // Verify that the sharding of instructions in the module are valid, and also
   // fill in missing sharding information.
   Status PreprocessSharding(HloModule* module);
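
For reference, one way a backend could opt into the single-step collectives, sketched under assumptions this diff does not show: it presumes AllGatherShards and AllReduceAlongShardingDims are virtual and that the *Internal entry points are accessible to subclasses; MyBackendSpmdPartitioner is a hypothetical name.

#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"

namespace xla {
namespace spmd {

// Hypothetical backend-specific partitioner; assumes the base-class methods
// overridden here are virtual, which the diff above does not show.
class MyBackendSpmdPartitioner : public SpmdPartitioner {
 public:
  using SpmdPartitioner::SpmdPartitioner;

  HloInstruction* AllGatherShards(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator) override {
    // One all-gather over the combined partition groups instead of one
    // all-gather per sharded dimension.
    return AllGatherShardsInternal(b, operand, sharding, next_channel_id,
                                   selected_dims, collectives_creator,
                                   /*per_dim_ag=*/false);
  }

  HloInstruction* AllReduceAlongShardingDims(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator,
      HloComputation* reduction) override {
    // Likewise: a single all-reduce over the combined groups.
    return AllReduceAlongShardingDimsInternal(
        b, operand, sharding, next_channel_id, selected_dims,
        collectives_creator, reduction, /*per_dim_ar=*/false);
  }
};

}  // namespace spmd
}  // namespace xla

The default behavior is unchanged: the public AllGatherShards and AllReduceAlongShardingDims forward to the Internal variants with per_dim_ag/per_dim_ar set to true, so only backends that explicitly pass false get the single-step form.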