Adjust the collective-permute cross-host type to MULTI_HOST_NON_WORLD_LEVEL only.

PiperOrigin-RevId: 826327580

Parent: d90723f48e
Commit: d9c76aafeb
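With this change, a collective-permute is classified per source-target pair: if every source host sends only to itself, the permute is SINGLE_HOST; any pair that crosses a host boundary makes it MULTI_HOST_NON_WORLD_LEVEL, and the MULTI_HOST_WORLD_LEVEL bucket is no longer produced for collective-permute. A minimal standalone sketch of the rule (illustrative only; `Classify`, `Pair`, and `Type` are stand-in names, not the actual XLA API):

    #include <cstdint>
    #include <map>
    #include <set>
    #include <utility>
    #include <vector>

    enum class Type { kSingleHost, kMultiHostNonWorldLevel };
    using Pair = std::pair<int64_t, int64_t>;  // {source_device, target_device}

    // Group devices into host nodes (device / num_devices_per_host), then
    // require that every source node targets exactly one node: itself. Any
    // edge leaving its node makes the permute MULTI_HOST_NON_WORLD_LEVEL.
    Type Classify(const std::vector<Pair>& pairs, int num_devices_per_host) {
      std::map<int64_t, std::set<int64_t>> node_map;
      for (const auto& [source, target] : pairs) {
        node_map[source / num_devices_per_host]
            .insert(target / num_devices_per_host);
      }
      for (const auto& [source_node, targets] : node_map) {
        if (targets.size() > 1 || *targets.begin() != source_node) {
          return Type::kMultiHostNonWorldLevel;
        }
      }
      return Type::kSingleHost;
    }
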
@@ -105,7 +105,9 @@ cc_library(
         "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",

@@ -18,11 +18,12 @@ limitations under the License.
-#include <algorithm>
 #include <cstddef>
 #include <cstdint>
 #include <variant>
 #include <vector>

+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"

@@ -43,7 +44,24 @@ namespace xla {
 namespace gpu {
 namespace {

-struct CommunicationMetadata {
+// Computes a map from source node ID to a set of target node IDs for a
+// collective-permute instruction. A node ID is computed by dividing the device
+// (replica) ID by the number of devices per host.
+absl::flat_hash_map<int64_t, absl::flat_hash_set<int64_t>>
+GetSourceToTargetsNodeMap(const HloCollectivePermuteInstruction& instr,
+                          int num_devices_per_host) {
+  absl::flat_hash_map<int64_t, absl::flat_hash_set<int64_t>>
+      source_to_targets_node_map;
+  for (const auto& [source, target] : instr.source_target_pairs()) {
+    int64_t source_node = source / num_devices_per_host;
+    int64_t target_node = target / num_devices_per_host;
+    source_to_targets_node_map[source_node].insert(target_node);
+  }
+  return source_to_targets_node_map;
+}
+
+struct CollectiveMetadata {
   // map for ops with `replica_groups`, e.g. all-gather.
   absl::flat_hash_map<int64_t, size_t> node_to_participant_count;
   int num_devices_per_host;
   int64_t replica_count;

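For intuition, hand-worked outputs of the GetSourceToTargetsNodeMap helper above (assuming 8 devices per host; worked by hand, not produced by a real run):

    // {{0,7},{7,0}}   ->  {0: {0}}          all traffic stays on node 0
    // {{0,8},{8,0}}   ->  {0: {1}, 1: {0}}  every edge crosses hosts
    // {{0,7},{0,8}}   ->  {0: {0, 1}}       node 0 targets two nodes
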
@@ -66,14 +84,12 @@ bool SameParticipantCounts(const absl::flat_hash_map<int64_t, size_t>& lhs,
   return lhs_counts == rhs_counts;
 }

-absl::StatusOr<CommunicationMetadata> CommunicationContext(
-    const HloChannelInstruction& instr, int num_devices_per_host) {
+absl::StatusOr<CollectiveMetadata> CommunicationContext(
+    const HloCollectiveInstruction& instr, int num_devices_per_host) {
   absl::flat_hash_map<int64_t, size_t> node_to_participant_count;

-  if (const HloCollectiveInstruction* collective =
-          DynCast<HloCollectiveInstruction>(&instr)) {
-    for (const ReplicaGroup& replica_group :
-         collective->device_list().replica_groups()) {
+  for (const ReplicaGroup& replica_group :
+       instr.device_list().replica_groups()) {
     absl::flat_hash_map<int64_t, size_t> buffer;
     for (int64_t rank : replica_group.replica_ids()) {
       int64_t node_id = rank / num_devices_per_host;

@@ -81,34 +97,18 @@ absl::StatusOr<CommunicationMetadata> CommunicationContext(
     }
     if (!node_to_participant_count.empty() &&
         !SameParticipantCounts(buffer, node_to_participant_count)) {
-      return absl::FailedPreconditionError(
-          absl::StrCat("Non homogenous replica group: ",
-                       collective->device_list().ToString()));
+      return absl::FailedPreconditionError(absl::StrCat(
+          "Non homogenous replica group: ", instr.device_list().ToString()));
     }
     if (node_to_participant_count.empty()) {
       node_to_participant_count = buffer;
     }
   }
-  } else if (const HloCollectivePermuteInstruction* collective_permute =
-                 DynCast<HloCollectivePermuteInstruction>(&instr)) {
-    for (const auto& [source, target] :
-         collective_permute->source_target_pairs()) {
-      int64_t source_node = source / num_devices_per_host;
-      int64_t target_node = target / num_devices_per_host;
-      node_to_participant_count[source_node]++;
-      node_to_participant_count[target_node]++;
-    }
-  } else {
-    return absl::FailedPreconditionError(
-        "Cannot determine communication context for non-collective channel "
-        "instruction");
-  }

-  return CommunicationMetadata{node_to_participant_count, num_devices_per_host,
+  return CollectiveMetadata{node_to_participant_count, num_devices_per_host,
                             instr.GetModule()->config().replica_count()};
 }

-bool IsSingleHost(const CommunicationMetadata& pattern) {
+bool IsSingleHost(const CollectiveMetadata& pattern) {
   if (pattern.node_to_participant_count.size() == 1) {
     return true;
   }

@@ -117,7 +117,7 @@ bool IsSingleHost(const CommunicationMetadata& pattern) {
          pattern.replica_count <= pattern.num_devices_per_host;
 }

-bool IsWorldLevelCommunication(const CommunicationMetadata& pattern) {
+bool IsWorldLevelCommunication(const CollectiveMetadata& pattern) {
   if (!IsSingleHost(pattern) && pattern.node_to_participant_count.empty()) {
     return true;
   }

@@ -128,7 +128,7 @@ bool IsWorldLevelCommunication(const CommunicationMetadata& pattern) {
   });
 }

-bool IsNonWorldLevelCommunication(const CommunicationMetadata& pattern) {
+bool IsNonWorldLevelCommunication(const CollectiveMetadata& pattern) {
   return !IsSingleHost(pattern) && !IsWorldLevelCommunication(pattern);
 }

@@ -149,8 +149,10 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
     return absl::FailedPreconditionError("Only CUDA is supported.");
   }

-  TF_ASSIGN_OR_RETURN(CommunicationMetadata comm,
-                      CommunicationContext(instr, num_devices_per_host));
+  if (const auto* collective = DynCast<HloCollectiveInstruction>(&instr)) {
+    TF_ASSIGN_OR_RETURN(
+        CollectiveMetadata comm,
+        CommunicationContext(*collective, num_devices_per_host));
   if (IsSingleHost(comm)) {
     return GPUCommunicationType::SINGLE_HOST;
   }

@@ -160,6 +162,26 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
   if (IsNonWorldLevelCommunication(comm)) {
     return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL;
   }
+  } else if (const auto* collective_permute =
+                 DynCast<HloCollectivePermuteInstruction>(&instr)) {
+    const auto source_to_targets_node_map =
+        GetSourceToTargetsNodeMap(*collective_permute, num_devices_per_host);
+    for (const auto& [source_node, target_node_set] :
+         source_to_targets_node_map) {
+      if (target_node_set.size() > 1) {
+        return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL;
+      }
+      CHECK_EQ(target_node_set.size(), 1);
+      if (source_node != *target_node_set.begin()) {
+        return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL;
+      }
+    }
+    return GPUCommunicationType::SINGLE_HOST;
+  } else {
+    return absl::FailedPreconditionError(
+        "Cannot determine communication type for non-collective channel "
+        "instruction");
+  }

   return GPUCommunicationType::UNDEFINED;
 }

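Hand-worked trace of the new collective-permute branch against the tests below (assuming num_devices_per_host = 8; illustrative, not tool output):

    // {{0,7},{7,0}}           node map {0: {0}}        -> SINGLE_HOST
    // {{0,7},{0,8},{1,9},...} node map {0: {0,1}, ...} -> MULTI_HOST_NON_WORLD_LEVEL
    //                         (node 0 has two target nodes, so the
    //                          target_node_set.size() > 1 check fires)
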
@@ -266,6 +266,26 @@ TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermute) {
               IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
 }

+TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermuteSinglePair) {
+  absl::string_view kHlo = R"(
+    HloModule m, num_partitions=8
+
+    ENTRY e {
+      p = f32[128] parameter(0)
+      ROOT _ = f32[128] collective-permute(p),
+        source_target_pairs={{0,7},{7,0}}
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
+
+  HloChannelInstruction* instr = Cast<HloChannelInstruction>(
+      module->entry_computation()->root_instruction());
+  EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
+                                device_info().gpu_compute_capability()),
+              IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
+}
+
 TEST_F(CommunicationTypeTest, DetectNonWorldLevelCollectivePermute) {
   absl::string_view kHlo = R"(
     HloModule m, num_partitions=16

@@ -304,7 +324,35 @@ TEST_F(CommunicationTypeTest, DetectWorldLevelCollectivePermute) {
       module->entry_computation()->root_instruction());
   EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
                                 device_info().gpu_compute_capability()),
-              IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
+              IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
 }

+TEST_F(CommunicationTypeTest, DetectsCrossHostCollectivePermuteMixed) {
+  absl::string_view kHlo = R"(
+    HloModule m, num_partitions=16
+
+    ENTRY e {
+      p = f32[128] parameter(0)
+      ROOT _ = f32[128] collective-permute(p),
+        source_target_pairs={{0,7},
+                             {0,8},
+                             {1,9},
+                             {2,10},
+                             {3,11},
+                             {4,12},
+                             {5,13},
+                             {6,14},
+                             {7,15}}
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
+
+  HloChannelInstruction* instr = Cast<HloChannelInstruction>(
+      module->entry_computation()->root_instruction());
+  EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
+                                device_info().gpu_compute_capability()),
+              IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
+}
+
 } // namespace

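Note on the flipped expectation in DetectWorldLevelCollectivePermute: under the new rule a collective-permute with any cross-host edge is reported as MULTI_HOST_NON_WORLD_LEVEL, so a permute that shifts ranks across hosts (the test's source_target_pairs are elided above) no longer yields MULTI_HOST_WORLD_LEVEL even when it spans every device.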