[XLA:GPU] Use a single intra-host ragged-all-to-all in the decomposition.

Instead of issuing two intra-host ragged-all-to-alls and concatenating their results, we can grow the temporary output buffer by a factor of num_hosts and shift the output offsets so that each host writes into its own slice. This saves latency, since the decomposition now needs only one multi-GPU synchronization instead of two.
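For intuition, a minimal scalar sketch of the offset adjustment follows. The pass performs the equivalent on-device (the actual shifting lives in the CorrectOffsets helper, whose body is outside this diff); the function name and the flattened [num_hosts, num_devices_per_host, num_updates] offset layout below are illustrative assumptions, not the pass's API.

#include <cstdint>
#include <vector>

// Hypothetical scalar model of the offset correction. The real pass applies
// it to the output-offsets operand (operand 4); `output_rows` corresponds to
// dim 0 of the original ragged-all-to-all output, i.e.
// ragged_all_to_all->operand(1)->shape().dimensions(0).
std::vector<int64_t> ShiftOffsetsPerHost(const std::vector<int64_t>& offsets,
                                         int64_t num_hosts,
                                         int64_t output_rows) {
  // offsets is flattened as [num_hosts, num_devices_per_host, num_updates],
  // so all updates headed for host h occupy one contiguous chunk.
  std::vector<int64_t> shifted(offsets.size());
  const int64_t chunk = offsets.size() / num_hosts;
  for (int64_t h = 0; h < num_hosts; ++h) {
    for (int64_t i = 0; i < chunk; ++i) {
      // Host h's updates land in rows [h * output_rows, (h + 1) * output_rows)
      // of the num_hosts-times-larger temporary buffer.
      shifted[h * chunk + i] = offsets[h * chunk + i] + h * output_rows;
    }
  }
  return shifted;
}

Because each host now writes into its own disjoint row range of the enlarged buffer, the subsequent inter-host all-to-all can exchange those slices directly and the old concatenate disappears.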

PiperOrigin-RevId: 826122665
Author: Oleg Shyshkov, 2025-10-30 11:43:00 -07:00 (committed by TensorFlower Gardener)
parent 9b51864c7b
commit 9eeebc9be5
2 changed files with 59 additions and 49 deletions

@@ -254,61 +254,75 @@ absl::StatusOr<bool> DecomposeCombineRaggedAllToAll(
       HloInstruction::CreateConstant(LiteralUtil::Zero(
           ragged_all_to_all->operand(1)->shape().element_type())));
+  Shape tmp_output_shape = ragged_all_to_all->shape();
+  tmp_output_shape.set_dimensions(0,
+                                  num_hosts * tmp_output_shape.dimensions(0));
   auto* zero_broadcast =
       computation->AddInstruction(HloInstruction::CreateBroadcast(
-          /*shape=*/ragged_all_to_all->operand(1)->shape(), zero,
-          /*broadcast_dimensions=*/{}));
+          /*shape=*/tmp_output_shape, zero, /*broadcast_dimensions=*/{}));
 
-  int64_t num_updates_per_host =
-      ragged_all_to_all->operand(2)->shape().dimensions(0) / num_hosts;
+  int64_t num_devices_in_replica_per_host = num_devices_in_replica / num_hosts;
 
-  auto slice_metadata_operand = [&](int64_t host_id,
-                                    HloInstruction* metadata_operand) {
-    Shape slice_shape = metadata_operand->shape();
-    slice_shape.set_dimensions(0, num_updates_per_host);
+  int64_t num_updates_per_replica =
+      ragged_all_to_all->operand(2)->shape().dimensions(0) /
+      num_devices_in_replica;
-    return computation->AddInstruction(HloInstruction::CreateSlice(
-        /*shape=*/slice_shape,
-        /*operand=*/metadata_operand,
-        /*start_indices=*/{num_updates_per_host * host_id},
-        /*limit_indices=*/{num_updates_per_host * (host_id + 1)},
-        /*strides=*/{1}));
+  auto get_intra_host_metadata = [&](HloInstruction* metadata_operand,
+                                     bool correct_offsets) {
+    metadata_operand =
+        computation->AddInstruction(HloInstruction::CreateReshape(
+            /*shape=*/ShapeUtil::MakeShape(
+                metadata_operand->shape().element_type(),
+                {num_hosts, num_devices_in_replica_per_host,
+                 num_updates_per_replica}),
+            /*operand=*/metadata_operand));
+    if (correct_offsets) {
+      metadata_operand =
+          CorrectOffsets(ragged_all_to_all->operand(1)->shape().dimensions(0),
+                         metadata_operand, computation);
+    }
+    metadata_operand =
+        computation->AddInstruction(HloInstruction::CreateTranspose(
+            /*shape=*/ShapeUtil::MakeShape(
+                metadata_operand->shape().element_type(),
+                {num_devices_in_replica_per_host, num_hosts,
+                 num_updates_per_replica}),
+            /*operand=*/metadata_operand,
+            /*dimensions=*/{1, 0, 2}));
+    return computation->AddInstruction(HloInstruction::CreateReshape(
+        /*shape=*/ragged_all_to_all->operand(2)->shape(),
+        /*operand=*/metadata_operand));
   };
 
-  absl::InlinedVector<HloInstruction*, 4> intra_host_ragged_all_to_alls(
-      num_hosts);
-  for (int64_t host_id = 0; host_id < num_hosts; ++host_id) {
-    absl::InlinedVector<HloInstruction*, 4> ragged_all_to_all_operands{
-        ragged_all_to_all->mutable_operand(0),
-        zero_broadcast,
-        slice_metadata_operand(host_id, ragged_all_to_all->mutable_operand(2)),
-        slice_metadata_operand(host_id, ragged_all_to_all->mutable_operand(3)),
-        slice_metadata_operand(host_id, ragged_all_to_all->mutable_operand(4)),
-        slice_metadata_operand(host_id, ragged_all_to_all->mutable_operand(5)),
-    };
+  absl::InlinedVector<HloInstruction*, 4> intra_host_ragged_all_to_all_operands{
+      ragged_all_to_all->mutable_operand(0),
+      zero_broadcast,
+      get_intra_host_metadata(ragged_all_to_all->mutable_operand(2),
+                              /*correct_offsets=*/false),
+      get_intra_host_metadata(ragged_all_to_all->mutable_operand(3),
+                              /*correct_offsets=*/false),
+      get_intra_host_metadata(ragged_all_to_all->mutable_operand(4),
+                              /*correct_offsets=*/true),
+      get_intra_host_metadata(ragged_all_to_all->mutable_operand(5),
+                              /*correct_offsets=*/false),
+  };
 
-    intra_host_ragged_all_to_alls[host_id] =
-        computation->AddInstruction(HloInstruction::CreateRaggedAllToAll(
-            /*shape=*/ragged_all_to_all->shape(),
-            /*operands=*/ragged_all_to_all_operands,
-            /*replica_groups=*/intra_host_replica_groups,
-            /*channel_id=*/ragged_all_to_all->channel_id().has_value()
-                ? std::make_optional(NextChannelId(*computation->parent()))
-                : std::nullopt));
-  }
-
-  Shape concatenated_inputs_shape = ragged_all_to_all->shape();
-  concatenated_inputs_shape.set_dimensions(
-      0, num_hosts * ragged_all_to_all->shape().dimensions(0));
-  HloInstruction* concatenated_inputs =
-      computation->AddInstruction(HloInstruction::CreateConcatenate(
-          /*shape=*/concatenated_inputs_shape,
-          /*operands=*/intra_host_ragged_all_to_alls, /*dimension=*/0));
+  HloInstruction* intra_host_ragged_all_to_all =
+      computation->AddInstruction(HloInstruction::CreateRaggedAllToAll(
+          /*shape=*/zero_broadcast->shape(),
+          /*operands=*/intra_host_ragged_all_to_all_operands,
+          /*device_list=*/CollectiveDeviceList(intra_host_replica_groups),
+          /*channel_id=*/ragged_all_to_all->channel_id().has_value()
+              ? std::make_optional(NextChannelId(*computation->parent()))
+              : std::nullopt));
 
   HloInstruction* local_inputs =
       computation->AddInstruction(HloInstruction::CreateAllToAll(
-          concatenated_inputs->shape(), {concatenated_inputs},
+          intra_host_ragged_all_to_all->shape(), {intra_host_ragged_all_to_all},
          /*device_list=*/CollectiveDeviceList(inter_host_replica_groups),
          /*constrain_layout=*/false,
          /*channel_id=*/ragged_all_to_all->channel_id().has_value()

@@ -323,8 +337,6 @@ absl::StatusOr<bool> DecomposeCombineRaggedAllToAll(
   }
 
   HloInstruction* output_offsets = ragged_all_to_all->mutable_operand(4);
-  int64_t num_updates_per_replica =
-      output_offsets->shape().dimensions(0) / num_devices_in_replica;
   output_offsets = computation->AddInstruction(HloInstruction::CreateReshape(
       /*shape=*/ShapeUtil::MakeShape(

@@ -344,8 +356,6 @@ absl::StatusOr<bool> DecomposeCombineRaggedAllToAll(
   HloInstruction* corrected_output_offsets = output_offsets;
-  int64_t num_devices_in_replica_per_host = num_devices_in_replica / num_hosts;
-
   corrected_output_offsets =
       computation->AddInstruction(HloInstruction::CreateReshape(
           /*shape=*/ShapeUtil::MakeShape(

@@ -151,7 +151,7 @@ ENTRY main {
   TF_EXPECT_OK(HloCSE(true).Run(module.get()));
   EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
-// CHECK-COUNT-2: ragged-all-to-all{{.*}}, replica_groups={{[{]}}{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15}{{[}]}}
+// CHECK: ragged-all-to-all{{.*}}, replica_groups={{[{]}}{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15}{{[}]}}
 // CHECK: all-to-all{{.*}}, replica_groups={{[{]}}{0,8},{1,9},{2,10},{3,11},{4,12},{5,13},{6,14},{7,15}{{[}]}}
 // CHECK: all-to-all{{.*}}, replica_groups={{[{]}}{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}{{[}]}}
 // CHECK: ragged-all-to-all{{.*}}, replica_groups={{[{]}}{0},{1},{2},{3},{4},{5},{6},{7},{8},{9},{10},{11},{12},{13},{14},{15}{{[}]}}