[XLA:GPU] Use a single intra-host ragged-all-to-all in the decomposition.
Instead of two ragged-all-to-alls followed by a concat, we can scale the temporary output buffer by num_hosts (doubling it for two hosts) and adjust the output offsets. This way we save latency by performing only one multi-GPU synchronization.
PiperOrigin-RevId: 826122665
This commit is contained in:
parent 9b51864c7b
commit 9eeebc9be5
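The core of the change: rather than running one intra-host ragged-all-to-all per host and concatenating the results, the pass now allocates a temporary output buffer of num_hosts times the original output size and shifts each update's output offset into the slot of its destination host, so a single intra-host collective produces the same stacked layout. The CorrectOffsets call in the diff below presumably implements that shift on the reshaped offset metadata; here is a minimal standalone sketch of the index arithmetic (plain C++ with illustrative names, not the pass's HLO-building code):

    #include <cstdint>
    #include <vector>

    // Sketch only: shift each update's output offset into the slot of the
    // enlarged temporary buffer owned by its destination host. With a buffer
    // of num_hosts * output_size elements, host h's results land in
    // [h * output_size, (h + 1) * output_size), so no concatenate is needed.
    std::vector<int64_t> ShiftOffsetsPerHost(int64_t output_size,
                                             int64_t num_updates_per_host,
                                             const std::vector<int64_t>& offsets) {
      std::vector<int64_t> shifted(offsets.size());
      for (size_t i = 0; i < offsets.size(); ++i) {
        int64_t host_id = static_cast<int64_t>(i) / num_updates_per_host;
        shifted[i] = offsets[i] + host_id * output_size;
      }
      return shifted;
    }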
@@ -254,61 +254,75 @@ absl::StatusOr<bool> DecomposeCombineRaggedAllToAll(
       HloInstruction::CreateConstant(LiteralUtil::Zero(
           ragged_all_to_all->operand(1)->shape().element_type())));
 
+  Shape tmp_output_shape = ragged_all_to_all->shape();
+  tmp_output_shape.set_dimensions(0,
+                                  num_hosts * tmp_output_shape.dimensions(0));
+
   auto* zero_broadcast =
       computation->AddInstruction(HloInstruction::CreateBroadcast(
-          /*shape=*/ragged_all_to_all->operand(1)->shape(), zero,
-          /*broadcast_dimensions=*/{}));
+          /*shape=*/tmp_output_shape, zero, /*broadcast_dimensions=*/{}));
 
-  int64_t num_updates_per_host =
-      ragged_all_to_all->operand(2)->shape().dimensions(0) / num_hosts;
+  int64_t num_devices_in_replica_per_host = num_devices_in_replica / num_hosts;
 
-  auto slice_metadata_operand = [&](int64_t host_id,
-                                    HloInstruction* metadata_operand) {
-    Shape slice_shape = metadata_operand->shape();
-    slice_shape.set_dimensions(0, num_updates_per_host);
+  int64_t num_updates_per_replica =
+      ragged_all_to_all->operand(2)->shape().dimensions(0) /
+      num_devices_in_replica;
 
-    return computation->AddInstruction(HloInstruction::CreateSlice(
-        /*shape=*/slice_shape,
-        /*operand=*/metadata_operand,
-        /*start_indices=*/{num_updates_per_host * host_id},
-        /*limit_indices=*/{num_updates_per_host * (host_id + 1)},
-        /*strides=*/{1}));
+  auto get_intra_host_metadata = [&](HloInstruction* metadata_operand,
+                                     bool correct_offsets) {
+    metadata_operand =
+        computation->AddInstruction(HloInstruction::CreateReshape(
+            /*shape=*/ShapeUtil::MakeShape(
+                metadata_operand->shape().element_type(),
+                {num_hosts, num_devices_in_replica_per_host,
+                 num_updates_per_replica}),
+            /*operand=*/metadata_operand));
+
+    if (correct_offsets) {
+      metadata_operand =
+          CorrectOffsets(ragged_all_to_all->operand(1)->shape().dimensions(0),
+                         metadata_operand, computation);
+    }
+
+    metadata_operand =
+        computation->AddInstruction(HloInstruction::CreateTranspose(
+            /*shape=*/ShapeUtil::MakeShape(
+                metadata_operand->shape().element_type(),
+                {num_devices_in_replica_per_host, num_hosts,
+                 num_updates_per_replica}),
+            /*operand=*/metadata_operand,
+            /*dimensions=*/{1, 0, 2}));
+
+    return computation->AddInstruction(HloInstruction::CreateReshape(
+        /*shape=*/ragged_all_to_all->operand(2)->shape(),
+        /*operand=*/metadata_operand));
   };
 
-  absl::InlinedVector<HloInstruction*, 4> intra_host_ragged_all_to_alls(
-      num_hosts);
-  for (int64_t host_id = 0; host_id < num_hosts; ++host_id) {
-    absl::InlinedVector<HloInstruction*, 4> ragged_all_to_all_operands{
-        ragged_all_to_all->mutable_operand(0),
-        zero_broadcast,
-        slice_metadata_operand(host_id, ragged_all_to_all->mutable_operand(2)),
-        slice_metadata_operand(host_id, ragged_all_to_all->mutable_operand(3)),
-        slice_metadata_operand(host_id, ragged_all_to_all->mutable_operand(4)),
-        slice_metadata_operand(host_id, ragged_all_to_all->mutable_operand(5)),
-    };
+  absl::InlinedVector<HloInstruction*, 4> intra_host_ragged_all_to_all_operands{
+      ragged_all_to_all->mutable_operand(0),
+      zero_broadcast,
+      get_intra_host_metadata(ragged_all_to_all->mutable_operand(2),
+                              /*correct_offsets=*/false),
+      get_intra_host_metadata(ragged_all_to_all->mutable_operand(3),
+                              /*correct_offsets=*/false),
+      get_intra_host_metadata(ragged_all_to_all->mutable_operand(4),
+                              /*correct_offsets=*/true),
+      get_intra_host_metadata(ragged_all_to_all->mutable_operand(5),
+                              /*correct_offsets=*/false),
+  };
 
-    intra_host_ragged_all_to_alls[host_id] =
-        computation->AddInstruction(HloInstruction::CreateRaggedAllToAll(
-            /*shape=*/ragged_all_to_all->shape(),
-            /*operands=*/ragged_all_to_all_operands,
-            /*replica_groups=*/intra_host_replica_groups,
-            /*channel_id=*/ragged_all_to_all->channel_id().has_value()
-                ? std::make_optional(NextChannelId(*computation->parent()))
-                : std::nullopt));
-  }
-
-  Shape concatenated_inputs_shape = ragged_all_to_all->shape();
-  concatenated_inputs_shape.set_dimensions(
-      0, num_hosts * ragged_all_to_all->shape().dimensions(0));
-
-  HloInstruction* concatenated_inputs =
-      computation->AddInstruction(HloInstruction::CreateConcatenate(
-          /*shape=*/concatenated_inputs_shape,
-          /*operands=*/intra_host_ragged_all_to_alls, /*dimension=*/0));
+  HloInstruction* intra_host_ragged_all_to_all =
+      computation->AddInstruction(HloInstruction::CreateRaggedAllToAll(
+          /*shape=*/zero_broadcast->shape(),
+          /*operands=*/intra_host_ragged_all_to_all_operands,
+          /*device_list=*/CollectiveDeviceList(intra_host_replica_groups),
+          /*channel_id=*/ragged_all_to_all->channel_id().has_value()
+              ? std::make_optional(NextChannelId(*computation->parent()))
+              : std::nullopt));
 
   HloInstruction* local_inputs =
       computation->AddInstruction(HloInstruction::CreateAllToAll(
-          concatenated_inputs->shape(), {concatenated_inputs},
+          intra_host_ragged_all_to_all->shape(), {intra_host_ragged_all_to_all},
          /*device_list=*/CollectiveDeviceList(inter_host_replica_groups),
          /*constrain_layout=*/false,
          /*channel_id=*/ragged_all_to_all->channel_id().has_value()
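The reshape/transpose/reshape chain in get_intra_host_metadata above regroups each flat metadata operand from (host, local device, update) order into (local device, host, update) order, so after the intra-host exchange every peer holds one contiguous metadata block per destination host. A standalone sketch of the equivalent permutation (illustrative C++, not the pass itself):

    #include <cstdint>
    #include <vector>

    // Illustrative reordering equivalent to the reshape -> transpose{1,0,2}
    // -> reshape chain: an entry at (host, device, update) moves to
    // (device, host, update) in the flattened metadata.
    std::vector<int64_t> RegroupMetadata(const std::vector<int64_t>& metadata,
                                         int64_t num_hosts,
                                         int64_t devices_per_host,
                                         int64_t updates_per_replica) {
      std::vector<int64_t> out(metadata.size());
      for (int64_t h = 0; h < num_hosts; ++h) {
        for (int64_t d = 0; d < devices_per_host; ++d) {
          for (int64_t u = 0; u < updates_per_replica; ++u) {
            int64_t src = (h * devices_per_host + d) * updates_per_replica + u;
            int64_t dst = (d * num_hosts + h) * updates_per_replica + u;
            out[dst] = metadata[src];
          }
        }
      }
      return out;
    }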
@@ -323,8 +337,6 @@ absl::StatusOr<bool> DecomposeCombineRaggedAllToAll(
   }
 
   HloInstruction* output_offsets = ragged_all_to_all->mutable_operand(4);
-  int64_t num_updates_per_replica =
-      output_offsets->shape().dimensions(0) / num_devices_in_replica;
 
   output_offsets = computation->AddInstruction(HloInstruction::CreateReshape(
       /*shape=*/ShapeUtil::MakeShape(
@@ -344,8 +356,6 @@ absl::StatusOr<bool> DecomposeCombineRaggedAllToAll(
 
   HloInstruction* corrected_output_offsets = output_offsets;
 
-  int64_t num_devices_in_replica_per_host = num_devices_in_replica / num_hosts;
-
   corrected_output_offsets =
       computation->AddInstruction(HloInstruction::CreateReshape(
           /*shape=*/ShapeUtil::MakeShape(
@@ -151,7 +151,7 @@ ENTRY main {
   TF_EXPECT_OK(HloCSE(true).Run(module.get()));
 
   EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
-    // CHECK-COUNT-2: ragged-all-to-all{{.*}}, replica_groups={{[{]}}{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15}{{[}]}}
+    // CHECK: ragged-all-to-all{{.*}}, replica_groups={{[{]}}{0,1,2,3,4,5,6,7},{8,9,10,11,12,13,14,15}{{[}]}}
     // CHECK: all-to-all{{.*}}, replica_groups={{[{]}}{0,8},{1,9},{2,10},{3,11},{4,12},{5,13},{6,14},{7,15}{{[}]}}
     // CHECK: all-to-all{{.*}}, replica_groups={{[{]}}{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}{{[}]}}
     // CHECK: ragged-all-to-all{{.*}}, replica_groups={{[{]}}{0},{1},{2},{3},{4},{5},{6},{7},{8},{9},{10},{11},{12},{13},{14},{15}{{[}]}}
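For reference, the replica groups asserted by the FileCheck lines above follow directly from the test's 2-host, 8-GPU-per-host topology: the intra-host ragged-all-to-all groups all devices of one host together, the inter-host all-to-all pairs devices with the same local rank across hosts, and the final local ragged-all-to-all uses singleton groups. A small sketch of how such groups could be enumerated (illustrative helpers, not the test's code):

    #include <cstdint>
    #include <vector>

    // Intra-host groups for 2 hosts x 8 devices: {0..7}, {8..15}.
    std::vector<std::vector<int64_t>> IntraHostGroups(int64_t num_hosts,
                                                      int64_t devices_per_host) {
      std::vector<std::vector<int64_t>> groups(num_hosts);
      for (int64_t h = 0; h < num_hosts; ++h) {
        for (int64_t d = 0; d < devices_per_host; ++d) {
          groups[h].push_back(h * devices_per_host + d);
        }
      }
      return groups;
    }

    // Inter-host groups for 2 hosts x 8 devices: {0,8}, {1,9}, ..., {7,15}.
    std::vector<std::vector<int64_t>> InterHostGroups(int64_t num_hosts,
                                                      int64_t devices_per_host) {
      std::vector<std::vector<int64_t>> groups(devices_per_host);
      for (int64_t d = 0; d < devices_per_host; ++d) {
        for (int64_t h = 0; h < num_hosts; ++h) {
          groups[d].push_back(h * devices_per_host + d);
        }
      }
      return groups;
    }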