mirror of https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
PR #32836: [GPU] Dispatch S-curve model to single-partition multi-host topology
Imported from GitHub PR https://github.com/openxla/xla/pull/32836

📝 Summary of Changes
- Renamed the SINGLE_HOST communication type to SINGLE_PARTITION (a fast-interconnect domain) to accommodate multi-node NVLink (MNNVL) topologies.
- Piped the auto-detected partition size into communication-type determination, and exposed the partition size in SolGPUCostModel::Config for AOT compilation.

🎯 Justification
The S-curve model cannot capture NVLink latency, so a single fast-interconnect domain, including an MNNVL topology, should use the latency-table model instead. This PR updates the routing mechanism so that an MNNVL domain is treated as a single partition, whereas previously a host was assumed to be equivalent to a partition.

🚀 Kind of Contribution
✨ New Feature

📊 Benchmark (for Performance Improvements)
N/A

🧪 Unit Tests: Added unit tests for the model-dispatching mechanism.
🧪 Execution Tests: Behavior is unchanged for non-MNNVL topologies; N/A.

Copybara import of the project:

-- a9544375934873f7b888fdb5ff6c9dc6ee8b0e6c by Terry Sun <tesun@nvidia.com>:

use partition size for static model dispatching

-- e3445a5deb8da10146e90c50da5598f91cfe0a69 by Terry Sun <tesun@nvidia.com>:

expose partition size to config

-- 212535ce891b8eb96ebb3c1e215a91d2b5035594 by Terry Sun <tesun@nvidia.com>:

better modularity

-- a9fe8a0f89dea9e2811d76a3570c7398df8dd756 by Terry Sun <tesun@nvidia.com>:

better code structure and doc string

-- a64a2b5ed1d45d815c6a2c47628b4d9ebb8368bd by Terry Sun <tesun@nvidia.com>:

update naming

Merging this change closes #32836

PiperOrigin-RevId: 826697791
This commit is contained in:
parent dad4fb74cd
commit 8134117476
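As background for the hunks below, the model routing this PR implements can be summarized with a small standalone sketch. This is an illustration distilled from the diff, not XLA's actual code; Model and DispatchModel are hypothetical names.

#include <cstdint>
#include <iostream>

// Hypothetical names for illustration; not XLA's actual API.
enum class Model { LATENCY_TABLE, S_CURVE };

// Routing rule described in the PR: a collective confined to one
// fast-interconnect (e.g. MNNVL) partition uses the latency-table model,
// because the S-curve model cannot capture NVLink latency; traffic that
// crosses partitions keeps using the S-curve (SoL) model.
Model DispatchModel(int64_t num_participants, int64_t partition_size) {
  return num_participants <= partition_size ? Model::LATENCY_TABLE
                                            : Model::S_CURVE;
}

int main() {
  // 16 GPUs spanning 2 hosts: with the old host-based rule (size 8) this
  // went to the S-curve model; with an MNNVL partition of 16 it now goes
  // to the latency-table model.
  std::cout << (DispatchModel(16, 8) == Model::S_CURVE) << "\n";         // 1
  std::cout << (DispatchModel(16, 16) == Model::LATENCY_TABLE) << "\n";  // 1
}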
@@ -70,6 +70,11 @@ constexpr char kSolChunkSizeBytes[] = "chunk_size_bytes";
 // cost model.
 constexpr char kSolGpusPerNode[] = "gpus_per_node";
 
+// Defines the partition size (number of devices per fast-interconnect domain)
+// used by the SoL cost model. This is necessary for AOT compilation when the
+// partition is larger than a node.
+constexpr char kSolPartitionSize[] = "partition_size";
+
 }  // namespace xla
 
 #endif  // XLA_SERVICE_COLLECTIVE_UTILS_H_
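The new kSolPartitionSize knob is read from the SoL cost-model option string by GetPlatformConfig (see the hunk further down that parses kSolChunkSizeBytes and kSolPartitionSize). Below is a rough standalone sketch of that key=value parsing; the structs and Parse are simplified stand-ins, not XLA's actual GetPlatformConfig or flag plumbing.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

// Hypothetical stand-in for SolGPUCostModel::Config.
struct Config {
  int64_t gpus_per_node = 8;
  int64_t partition_size = 0;  // 0 means unset.
};

// Parses "key=value,key=value" pairs the way the SoL options are shaped.
Config Parse(const std::string& options) {
  Config config;
  std::stringstream ss(options);
  std::string token;
  while (std::getline(ss, token, ',')) {
    const size_t eq = token.find('=');
    if (eq == std::string::npos) continue;
    const std::string name = token.substr(0, eq);
    const int64_t value = std::stoll(token.substr(eq + 1));
    if (name == "gpus_per_node" && value > 0) config.gpus_per_node = value;
    if (name == "partition_size" && value > 0) config.partition_size = value;
  }
  return config;
}

int main() {
  // E.g. a GB200 NVL72-style domain: 72 GPUs in one NVLink partition.
  Config c = Parse("gpus_per_node=4,partition_size=72");
  std::cout << c.gpus_per_node << " " << c.partition_size << "\n";  // 4 72
}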
@@ -2012,6 +2012,11 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
 absl::StatusOr<std::unique_ptr<HloModule>> GpuCompiler::RunHloPasses(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
     const CompileOptions& options) {
+  // TODO: rename slice_size to partition_size in CompileOptions.
+  if (options.slice_size > 0) {
+    module->mutable_config().set_partition_size(options.slice_size);
+  }
+
   const DebugOptions debug_opts = module->config().debug_options();
   TF_RETURN_IF_ERROR(LoadAutotuneResultsFromFile(debug_opts));
   bool is_deviceless = options.target_config.has_value() ||
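The guard above only overrides the module config when a positive slice size is supplied, so the default of 0 keeps meaning "unset". A minimal standalone sketch of that behavior, with hypothetical stand-ins for CompileOptions and HloModuleConfig:

#include <cstdint>
#include <iostream>

// Hypothetical mirrors of the types in the hunk above.
struct ModuleConfig { int64_t partition_size = 0; };
struct CompileOptions { int64_t slice_size = 0; };

// A non-positive slice_size leaves partition_size at its "unset" default,
// so downstream code can fall back to gpus_per_node.
void ApplySliceSize(const CompileOptions& options, ModuleConfig& config) {
  if (options.slice_size > 0) {
    config.partition_size = options.slice_size;
  }
}

int main() {
  ModuleConfig config;
  ApplySliceSize(CompileOptions{/*slice_size=*/-1}, config);
  std::cout << config.partition_size << "\n";  // 0 (still unset)
  ApplySliceSize(CompileOptions{/*slice_size=*/72}, config);
  std::cout << config.partition_size << "\n";  // 72
}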
@@ -148,7 +148,7 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
       int num_hosts) {
     IotaReplicaGroupList iota(1, 1);
     switch (comm) {
-      case GPUCommunicationType::SINGLE_HOST:
+      case GPUCommunicationType::SINGLE_PARTITION:
         iota = IotaReplicaGroupList(num_hosts, kNumGpusPerHost);
         break;
       case GPUCommunicationType::MULTI_HOST_WORLD_LEVEL:
@@ -225,28 +225,28 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
             },
             {
                 /*opcode=*/HloOpcode::kAllReduce,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/1024,
                 /*num_nodes=*/2,
                 /*network_througput_bytes=*/2048,
             },
             {
                 /*opcode=*/HloOpcode::kAllReduce,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/2 * 1024,
                 /*num_nodes=*/2,
                 /*network_througput_bytes=*/2 * 2048,
             },
             {
                 /*opcode=*/HloOpcode::kAllReduce,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/1024,
                 /*num_nodes=*/4,
                 /*network_througput_bytes=*/2048,
             },
             {
                 /*opcode=*/HloOpcode::kAllReduce,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/2 * 1024,
                 /*num_nodes=*/4,
                 /*network_througput_bytes=*/2 * 2048,

@@ -309,28 +309,28 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
             },
             {
                 /*opcode=*/HloOpcode::kReduceScatter,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/1024,
                 /*num_nodes=*/2,
                 /*network_througput_bytes=*/2048,
             },
             {
                 /*opcode=*/HloOpcode::kReduceScatter,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/2 * 1024,
                 /*num_nodes=*/2,
                 /*network_througput_bytes=*/2 * 2048,
             },
             {
                 /*opcode=*/HloOpcode::kReduceScatter,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/1024,
                 /*num_nodes=*/4,
                 /*network_througput_bytes=*/2048,
             },
             {
                 /*opcode=*/HloOpcode::kReduceScatter,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/2 * 1024,
                 /*num_nodes=*/4,
                 /*network_througput_bytes=*/2 * 2048,

@@ -393,35 +393,35 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
             },
             {
                 /*opcode=*/HloOpcode::kAllGather,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/1024,
                 /*num_nodes=*/2,
                 /*network_througput_bytes=*/2048,
             },
             {
                 /*opcode=*/HloOpcode::kAllGather,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/2 * 1024,
                 /*num_nodes=*/2,
                 /*network_througput_bytes=*/2 * 2048,
             },
             {
                 /*opcode=*/HloOpcode::kAllGather,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/1024,
                 /*num_nodes=*/4,
                 /*network_througput_bytes=*/2048,
             },
             {
                 /*opcode=*/HloOpcode::kAllGather,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/2 * 1024,
                 /*num_nodes=*/4,
                 /*network_througput_bytes=*/2 * 2048,
             },
             {
                 /*opcode=*/HloOpcode::kAllToAll,
-                /*comm=*/GPUCommunicationType::SINGLE_HOST,
+                /*comm=*/GPUCommunicationType::SINGLE_PARTITION,
                 /*tensor_size=*/1024,
                 /*num_nodes=*/1,
                 /*network_througput_bytes=*/1024,
@@ -574,60 +574,61 @@ INSTANTIATE_TEST_SUITE_P(
                 /*expected_duration=*/absl::Milliseconds(2500),
             },
             {
-                /*test_name=*/"AR_single_host_aligned_extrapolate_nodes",
+                /*test_name=*/"AR_SINGLE_PARTITION_aligned_extrapolate_nodes",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kAllReduce,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/1024,
                     /*num_nodes=*/8,
                 },
                 /*expected_duration=*/absl::Milliseconds(500),
             },
             {
-                /*test_name=*/"AR_single_host_aligned_extrapolate_tensor_size",
+                /*test_name=*/"AR_SINGLE_PARTITION_aligned_extrapolate_tensor_size",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kAllReduce,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/4 * 1024,
                     /*num_nodes=*/2,
                 },
                 /*expected_duration=*/absl::Seconds(1),
             },
             {
-                /*test_name=*/"AR_single_host_aligned_interpolate_nodes",
+                /*test_name=*/"AR_SINGLE_PARTITION_aligned_interpolate_nodes",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kAllReduce,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/1024,
                     /*num_nodes=*/3,
                 },
                 /*expected_duration=*/absl::Milliseconds(500),
             },
             {
-                /*test_name=*/"AR_single_host_aligned_interpolate_tensor_size",
+                /*test_name=*/"AR_SINGLE_PARTITION_aligned_interpolate_tensor_size",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kAllReduce,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/1024 + 256,
                     /*num_nodes=*/2,
                 },
                 /*expected_duration=*/absl::Milliseconds(625),
             },
             {
-                /*test_name=*/"ARS_single_host_aligned_interpolate_tensor_size",
+                /*test_name=*/"ARS_SINGLE_PARTITION_aligned_interpolate_tensor_"
+                "size",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kAllReduceStart,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/1024 + 256,
                     /*num_nodes=*/2,
                 },

@@ -754,48 +755,48 @@ INSTANTIATE_TEST_SUITE_P(
                 /*expected_duration=*/absl::Milliseconds(2500),
             },
             {
-                /*test_name=*/"RS_single_host_aligned_extrapolate_nodes",
+                /*test_name=*/"RS_SINGLE_PARTITION_aligned_extrapolate_nodes",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kReduceScatter,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/1024,
                     /*num_nodes=*/8,
                 },
                 /*expected_duration=*/absl::Milliseconds(500),
             },
             {
-                /*test_name=*/"RS_single_host_aligned_extrapolate_tensor_size",
+                /*test_name=*/"RS_SINGLE_PARTITION_aligned_extrapolate_tensor_size",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kReduceScatter,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/4 * 1024,
                     /*num_nodes=*/2,
                 },
                 /*expected_duration=*/absl::Seconds(1),
             },
             {
-                /*test_name=*/"RS_single_host_aligned_interpolate_nodes",
+                /*test_name=*/"RS_SINGLE_PARTITION_aligned_interpolate_nodes",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kReduceScatter,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/1024,
                     /*num_nodes=*/3,
                 },
                 /*expected_duration=*/absl::Milliseconds(500),
             },
             {
-                /*test_name=*/"RS_single_host_aligned_interpolate_tensor_size",
+                /*test_name=*/"RS_SINGLE_PARTITION_aligned_interpolate_tensor_size",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kReduceScatter,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/1024 + 256,
                     /*num_nodes=*/2,
                 },

@@ -922,60 +923,61 @@ INSTANTIATE_TEST_SUITE_P(
                 /*expected_duration=*/absl::Milliseconds(2500),
             },
             {
-                /*test_name=*/"AG_single_host_aligned_extrapolate_nodes",
+                /*test_name=*/"AG_SINGLE_PARTITION_aligned_extrapolate_nodes",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kAllGather,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/1024,
                     /*num_nodes=*/8,
                 },
                 /*expected_duration=*/absl::Milliseconds(500),
             },
             {
-                /*test_name=*/"AG_single_host_aligned_extrapolate_tensor_size",
+                /*test_name=*/"AG_SINGLE_PARTITION_aligned_extrapolate_tensor_size",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kAllGather,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/4 * 1024,
                     /*num_nodes=*/2,
                 },
                 /*expected_duration=*/absl::Seconds(1),
             },
             {
-                /*test_name=*/"AG_single_host_aligned_interpolate_nodes",
+                /*test_name=*/"AG_SINGLE_PARTITION_aligned_interpolate_nodes",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kAllGather,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/1024,
                     /*num_nodes=*/3,
                 },
                 /*expected_duration=*/absl::Milliseconds(500),
             },
             {
-                /*test_name=*/"AG_single_host_aligned_interpolate_tensor_size",
+                /*test_name=*/"AG_SINGLE_PARTITION_aligned_interpolate_tensor_size",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kAllGather,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/1024 + 256,
                     /*num_nodes=*/2,
                 },
                 /*expected_duration=*/absl::Milliseconds(625),
             },
             {
-                /*test_name=*/"AGS_single_host_aligned_interpolate_tensor_size",
+                /*test_name=*/"AGS_SINGLE_PARTITION_aligned_interpolate_tensor_"
+                "size",
                 /*spec=*/
                 {
                     /*opcode=*/HloOpcode::kAllGatherStart,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/1024 + 256,
                     /*num_nodes=*/2,
                 },

@@ -1004,11 +1006,11 @@ INSTANTIATE_TEST_SUITE_P(
                 /*expected_duration=*/absl::Milliseconds(250),
             },
             {
-                /*test_name=*/"A2A_single_host_exact_match",
+                /*test_name=*/"A2A_SINGLE_PARTITION_exact_match",
                 {
                     /*opcode=*/HloOpcode::kAllToAll,
                     /*comm=*/
-                    GPUCommunicationType::SINGLE_HOST,
+                    GPUCommunicationType::SINGLE_PARTITION,
                     /*tensor_size=*/1024,
                     /*num_nodes=*/1,
                 },
@@ -138,6 +138,9 @@ SolGPUCostModel::Config GetPlatformConfig(
     } else if (option_name == kSolChunkSizeBytes &&
                absl::SimpleAtoi(option_value, &value) && value > 0) {
       config.chunk_size_bytes = value;
+    } else if (option_name == kSolPartitionSize &&
+               absl::SimpleAtoi(option_value, &value) && value > 0) {
+      config.partition_size = value;
     }
   }
   return config;
@@ -40,6 +40,8 @@ class SolGPUCostModel {
     absl::Duration rtt;
     int64_t gpus_per_node;
     int64_t chunk_size_bytes;
+    // Partition size (devices per fast-interconnect domain). 0 means unset.
+    int64_t partition_size;
   };
 
   enum CollectiveAlgorithmType {
@@ -189,6 +189,17 @@ absl::StatusOr<absl::Duration> DCNCollectiveDuration(
   return result;
 }
 
+int64_t GetPartitionSize(const HloInstruction& instr,
+                         const SolGPUCostModel::Config& sol_flags) {
+  if (sol_flags.partition_size > 0) {
+    return sol_flags.partition_size;
+  }
+  if (instr.GetModule()->config().partition_size() > 0) {
+    return instr.GetModule()->config().partition_size();
+  }
+  return sol_flags.gpus_per_node;
+}
+
 absl::StatusOr<absl::Duration> DispatchEstimation(
     const absl::StatusOr<GPUCommunicationType>& communication_type,
     const HloCollectiveInstruction& instr,
@@ -202,11 +213,12 @@ absl::StatusOr<absl::Duration> DispatchEstimation(
   GPUCommunicationType comm = *communication_type;
   TF_ASSIGN_OR_RETURN(auto num_groups_and_devices,
                       GetReplicaGroupCountAndSize(&instr));
+  int64_t partition_size = GetPartitionSize(instr, sol_flags);
 
   switch (comm) {
     case GPUCommunicationType::MULTI_HOST_WORLD_LEVEL: {
       return DCNCollectiveDuration(
-          num_groups_and_devices->second / sol_flags.gpus_per_node,
+          num_groups_and_devices->second / partition_size,
           /*num_communicators=*/num_groups_and_devices->first, instr,
           gpu_device_info, sol_flags, analysis, symbolic_expr_context);
     }

@@ -216,10 +228,11 @@ absl::StatusOr<absl::Duration> DispatchEstimation(
           /*num_communicators=*/num_groups_and_devices->first, instr,
           gpu_device_info, sol_flags, analysis, symbolic_expr_context);
     }
-    case GPUCommunicationType::SINGLE_HOST: {
+    case GPUCommunicationType::SINGLE_PARTITION: {
       if (collective_interpolator == nullptr) {
         return absl::InvalidArgumentError(
-            "Collective interpolator is required for single host collectives");
+            "Collective interpolator is required for single partition "
+            "collectives");
       }
       return collective_interpolator->EstimatedRuntime(instr);
     }
@@ -309,9 +322,10 @@ SolLatencyEstimator::ComputeCollectiveTime(
         absl::StrCat("Unsupported collective instruction: ", instr.ToString()));
   }
 
+  int64_t partition_size = GetPartitionSize(*collective_instr, sol_flags);
   TF_ASSIGN_OR_RETURN(
       GPUCommunicationType communication_type,
-      CommunicationType(sol_flags.gpus_per_node, *collective_instr,
+      CommunicationType(partition_size, *collective_instr,
                         gpu_device_info.gpu_compute_capability()));
   TF_ASSIGN_OR_RETURN(
       absl::Duration result,
@@ -99,7 +99,7 @@ absl::StatusOr<bool> CollectiveBackendAssigner::Run(
           << " slice_size_=" << slice_size_;
   bool use_nvshmem =
       (num_visible_devices_per_process_ == 1 ||
-       comm_type == GPUCommunicationType::SINGLE_HOST ||
+       comm_type == GPUCommunicationType::SINGLE_PARTITION ||
        (slice_size_ > 0 &&
         IsIntraNVLinkDomain(module->config(), slice_size_))) &&
       (!IsAllReduceOp(instr) || shape_size < threshold_in_bytes_);
@@ -44,26 +44,26 @@ namespace xla {
 namespace gpu {
 namespace {
 
-// Computes a map from source node ID to a set of target node IDs for a
-// collective-permute instruction. A node ID is computed by dividing the device
-// (replica) ID by the number of devices per host.
+// Computes a map from source partition ID to a set of target partition IDs for
+// a collective-permute instruction. A partition ID is computed by dividing the
+// device (replica) ID by the number of devices per partition.
 absl::flat_hash_map<int64_t, absl::flat_hash_set<int64_t>>
 GetSourceToTargetsNodeMap(const HloCollectivePermuteInstruction& instr,
-                          int num_devices_per_host) {
+                          int num_devices_per_partition) {
   absl::flat_hash_map<int64_t, absl::flat_hash_set<int64_t>>
-      source_to_targets_node_map;
+      source_to_targets_partition_map;
   for (const auto& [source, target] : instr.source_target_pairs()) {
-    int64_t source_node = source / num_devices_per_host;
-    int64_t target_node = target / num_devices_per_host;
-    source_to_targets_node_map[source_node].insert(target_node);
+    int64_t source_partition = source / num_devices_per_partition;
+    int64_t target_partition = target / num_devices_per_partition;
+    source_to_targets_partition_map[source_partition].insert(target_partition);
   }
-  return source_to_targets_node_map;
+  return source_to_targets_partition_map;
 }
 
 struct CollectiveMetadata {
   // map for ops with `replica_groups`, e.g. all-gather.
-  absl::flat_hash_map<int64_t, size_t> node_to_participant_count;
-  int num_devices_per_host;
+  absl::flat_hash_map<int64_t, size_t> partition_to_participant_count;
+  int num_devices_per_partition;
   int64_t replica_count;
 };
 
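The bucketing above is plain integer division: with num_devices_per_partition=8, ranks 0..7 map to partition 0 and ranks 8..15 to partition 1, so a permute pair stays inside one partition only when both endpoints divide to the same ID. A tiny standalone illustration (hypothetical data, not XLA code):

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Source/target device pairs of a hypothetical collective-permute.
  const std::vector<std::pair<int64_t, int64_t>> pairs = {{0, 7}, {6, 9}};
  const int64_t num_devices_per_partition = 8;
  for (const auto& [source, target] : pairs) {
    // Same integer division as GetSourceToTargetsNodeMap above.
    std::cout << "device " << source << " -> " << target << " is partition "
              << source / num_devices_per_partition << " -> "
              << target / num_devices_per_partition << "\n";
    // {0,7}: partition 0 -> 0 (stays inside one partition)
    // {6,9}: partition 0 -> 1 (crosses a partition boundary)
  }
}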
@@ -85,46 +85,48 @@ bool SameParticipantCounts(const absl::flat_hash_map<int64_t, size_t>& lhs,
 }
 
 absl::StatusOr<CollectiveMetadata> CommunicationContext(
-    const HloCollectiveInstruction& instr, int num_devices_per_host) {
-  absl::flat_hash_map<int64_t, size_t> node_to_participant_count;
+    const HloCollectiveInstruction& instr, int num_devices_per_partition) {
+  absl::flat_hash_map<int64_t, size_t> partition_to_participant_count;
 
   for (const ReplicaGroup& replica_group :
        instr.device_list().replica_groups()) {
     absl::flat_hash_map<int64_t, size_t> buffer;
     for (int64_t rank : replica_group.replica_ids()) {
-      int64_t node_id = rank / num_devices_per_host;
-      buffer[node_id]++;
+      int64_t partition_id = rank / num_devices_per_partition;
+      buffer[partition_id]++;
     }
-    if (!node_to_participant_count.empty() &&
-        !SameParticipantCounts(buffer, node_to_participant_count)) {
+    if (!partition_to_participant_count.empty() &&
+        !SameParticipantCounts(buffer, partition_to_participant_count)) {
       return absl::FailedPreconditionError(absl::StrCat(
           "Non homogenous replica group: ", instr.device_list().ToString()));
     }
-    if (node_to_participant_count.empty()) {
-      node_to_participant_count = buffer;
+    if (partition_to_participant_count.empty()) {
+      partition_to_participant_count = buffer;
     }
   }
-  return CollectiveMetadata{node_to_participant_count, num_devices_per_host,
+  return CollectiveMetadata{partition_to_participant_count,
+                            num_devices_per_partition,
                             instr.GetModule()->config().replica_count()};
 }
 
 bool IsSingleHost(const CollectiveMetadata& pattern) {
-  if (pattern.node_to_participant_count.size() == 1) {
+  if (pattern.partition_to_participant_count.size() == 1) {
     return true;
   }
   return pattern.replica_count > 0 &&
-         pattern.node_to_participant_count.empty() &&
-         pattern.replica_count <= pattern.num_devices_per_host;
+         pattern.partition_to_participant_count.empty() &&
+         pattern.replica_count <= pattern.num_devices_per_partition;
 }
 
 bool IsWorldLevelCommunication(const CollectiveMetadata& pattern) {
-  if (!IsSingleHost(pattern) && pattern.node_to_participant_count.empty()) {
+  if (!IsSingleHost(pattern) &&
+      pattern.partition_to_participant_count.empty()) {
     return true;
   }
   return absl::c_all_of(
-      pattern.node_to_participant_count, [&pattern](const auto& elem) {
-        const auto& [node_id, participant_count] = elem;
-        return participant_count == pattern.num_devices_per_host;
+      pattern.partition_to_participant_count, [&pattern](const auto& elem) {
+        const auto& [partition_id, participant_count] = elem;
+        return participant_count == pattern.num_devices_per_partition;
       });
 }
@@ -143,7 +145,7 @@ bool IsGPUSyncCollective(const HloInstruction& instr) {
 }
 
 absl::StatusOr<GPUCommunicationType> CommunicationType(
-    int num_devices_per_host, const HloChannelInstruction& instr,
+    int num_devices_per_partition, const HloChannelInstruction& instr,
     const se::GpuComputeCapability& gpu_version) {
   if (!gpu_version.IsCuda()) {
     return absl::FailedPreconditionError("Only CUDA is supported.");

@@ -152,9 +154,9 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
   if (const auto* collective = DynCast<HloCollectiveInstruction>(&instr)) {
     TF_ASSIGN_OR_RETURN(
         CollectiveMetadata comm,
-        CommunicationContext(*collective, num_devices_per_host));
+        CommunicationContext(*collective, num_devices_per_partition));
     if (IsSingleHost(comm)) {
-      return GPUCommunicationType::SINGLE_HOST;
+      return GPUCommunicationType::SINGLE_PARTITION;
     }
     if (IsWorldLevelCommunication(comm)) {
       return GPUCommunicationType::MULTI_HOST_WORLD_LEVEL;

@@ -164,19 +166,19 @@ absl::StatusOr<GPUCommunicationType> CommunicationType(
     }
   } else if (const auto* collective_permute =
                  DynCast<HloCollectivePermuteInstruction>(&instr)) {
-    const auto source_to_targets_node_map =
-        GetSourceToTargetsNodeMap(*collective_permute, num_devices_per_host);
-    for (const auto& [source_node, target_node_set] :
-         source_to_targets_node_map) {
-      if (target_node_set.size() > 1) {
+    const auto source_to_targets_partition_map = GetSourceToTargetsNodeMap(
+        *collective_permute, num_devices_per_partition);
+    for (const auto& [source_partition, target_partition_set] :
+         source_to_targets_partition_map) {
+      if (target_partition_set.size() > 1) {
         return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL;
       }
-      CHECK_EQ(target_node_set.size(), 1);
-      if (source_node != *target_node_set.begin()) {
+      CHECK_EQ(target_partition_set.size(), 1);
+      if (source_partition != *target_partition_set.begin()) {
        return GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL;
       }
     }
-    return GPUCommunicationType::SINGLE_HOST;
+    return GPUCommunicationType::SINGLE_PARTITION;
   } else {
     return absl::FailedPreconditionError(
         "Cannot determine communication type for non-collective channel "
@@ -35,13 +35,13 @@ enum class GPUCommunicationType {
   // the involved hosts has only a subset of its devices participating.
   MULTI_HOST_NON_WORLD_LEVEL = 2,
   // All devices participating in the collective operation reside on the same
-  // host machine.
-  SINGLE_HOST = 3
+  // fast-interconnect domain.
+  SINGLE_PARTITION = 3
 };
 
 // Returns the type of communication pattern for a channel instruction.
 absl::StatusOr<GPUCommunicationType> CommunicationType(
-    int num_devices_per_host, const HloChannelInstruction& instr,
+    int partition_size, const HloChannelInstruction& instr,
     const se::GpuComputeCapability& gpu_version);
 
 // Returns true if instruction is a synchronous collective op.
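To make the renamed enum concrete: with a partition size of 4, a replica group {0,1,2,3} is SINGLE_PARTITION; groups that fully cover every partition they touch are world-level; a group such as {0,1,2,3,4} spans partitions without filling them, so it is non-world-level. Below is a hedged re-derivation of that rule, simplified from CommunicationContext/IsSingleHost/IsWorldLevelCommunication above; it is not the real XLA implementation.

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

// Count participants per partition, then compare against partition_size.
const char* Classify(const std::vector<int64_t>& group,
                     int64_t partition_size) {
  std::map<int64_t, int64_t> per_partition;
  for (int64_t rank : group) ++per_partition[rank / partition_size];
  if (per_partition.size() == 1) return "SINGLE_PARTITION";
  for (const auto& [partition_id, count] : per_partition) {
    if (count != partition_size) return "MULTI_HOST_NON_WORLD_LEVEL";
  }
  return "MULTI_HOST_WORLD_LEVEL";
}

int main() {
  std::cout << Classify({0, 1, 2, 3}, 4) << "\n";              // SINGLE_PARTITION
  std::cout << Classify({0, 1, 2, 3, 4, 5, 6, 7}, 4) << "\n";  // MULTI_HOST_WORLD_LEVEL
  std::cout << Classify({0, 1, 2, 3, 4}, 4) << "\n";           // MULTI_HOST_NON_WORLD_LEVEL
}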
@@ -62,7 +62,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost8Devices) {
       module->entry_computation()->root_instruction());
   EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
                                 device_info().gpu_compute_capability()),
-              IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
+              IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
 }
 
 TEST_F(CommunicationTypeTest, DetectsSingleHost4Devices) {

@@ -85,7 +85,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost4Devices) {
       module->entry_computation()->root_instruction());
   EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
                                 device_info().gpu_compute_capability()),
-              IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
+              IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
 }
 
 TEST_F(CommunicationTypeTest, DetectsSingleHost16Devices) {

@@ -106,9 +106,9 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost16Devices) {
 
   HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
       module->entry_computation()->root_instruction());
-  EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
+  EXPECT_THAT(CommunicationType(/*partition_size=*/8, *instr,
                                 device_info().gpu_compute_capability()),
-              IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
+              IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
 }
 
 TEST_F(CommunicationTypeTest, DetectWorldLevelAllDevices) {

@@ -201,7 +201,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHost16DevicesForEmptyReplicaGroups) {
       module->entry_computation()->root_instruction());
   EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/16, *instr,
                                 device_info().gpu_compute_capability()),
-              IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
+              IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
 }
 
 TEST_F(CommunicationTypeTest, DetectWorldLevel8DevicesForEmptyReplicaGroups) {

@@ -263,7 +263,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermute) {
       module->entry_computation()->root_instruction());
   EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
                                 device_info().gpu_compute_capability()),
-              IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
+              IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
 }
 
 TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermuteSinglePair) {

@@ -283,7 +283,7 @@ TEST_F(CommunicationTypeTest, DetectsSingleHostCollectivePermuteSinglePair) {
       module->entry_computation()->root_instruction());
   EXPECT_THAT(CommunicationType(/*num_devices_per_host=*/8, *instr,
                                 device_info().gpu_compute_capability()),
-              IsOkAndHolds(GPUCommunicationType::SINGLE_HOST));
+              IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
 }
 
 TEST_F(CommunicationTypeTest, DetectNonWorldLevelCollectivePermute) {
@@ -355,5 +355,130 @@ TEST_F(CommunicationTypeTest, DetectsCrossHostCollectivePermuteMixed) {
               IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
 }
 
+TEST_F(CommunicationTypeTest, DetectsSinglePartitionMultiHost) {
+  // 16 devices across 2 hosts with partition_size=16 (a single partition
+  // spanning 2 hosts).
+  absl::string_view kHlo = R"(
+    HloModule m, num_partitions=16
+
+    ENTRY e {
+      p = f32[128] parameter(0)
+      ROOT _ = f32[2048] all-gather(p),
+        dimensions={0},
+        use_global_device_ids=true,
+        channel_id=1,
+        replica_groups=[1,16]<=[16]
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
+
+  HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
+      module->entry_computation()->root_instruction());
+  EXPECT_THAT(CommunicationType(/*partition_size=*/16, *instr,
+                                device_info().gpu_compute_capability()),
+              IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
+}
+
+TEST_F(CommunicationTypeTest, DetectsMultiPartitionWith32DevicePartitions) {
+  // 64 devices across 2 partitions with partition_size=32.
+  absl::string_view kHlo = R"(
+    HloModule m, num_partitions=16
+
+    ENTRY e {
+      p = f32[128] parameter(0)
+      ROOT _ = f32[2048] all-gather(p),
+        dimensions={0},
+        use_global_device_ids=true,
+        channel_id=1,
+        replica_groups=[1, 64]<=[64]
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
+
+  HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
+      module->entry_computation()->root_instruction());
+  EXPECT_THAT(CommunicationType(/*partition_size=*/32, *instr,
+                                device_info().gpu_compute_capability()),
+              IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
+}
+
+TEST_F(CommunicationTypeTest, DetectsMultiPartitionNonRailAligned) {
+  // 64 devices with partition_size=36: partition 0 has 36 devices, partition 1
+  // has 28 devices.
+  absl::string_view kHlo = R"(
+    HloModule m, num_partitions=12
+
+    ENTRY e {
+      p = f32[128] parameter(0)
+      ROOT _ = f32[1536] all-gather(p),
+        dimensions={0},
+        use_global_device_ids=true,
+        channel_id=1,
+        replica_groups=[1, 64]<=[64]
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
+
+  HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
+      module->entry_computation()->root_instruction());
+  // With partition_size=36, the group spans 2 partitions but is not
+  // rail-aligned (36 and 28 devices).
+  EXPECT_THAT(CommunicationType(/*partition_size=*/36, *instr,
+                                device_info().gpu_compute_capability()),
+              IsOkAndHolds(GPUCommunicationType::MULTI_HOST_NON_WORLD_LEVEL));
+}
+
+TEST_F(CommunicationTypeTest, DetectsSinglePartitionSubset) {
+  // 6 devices within a single partition (partition_size=36).
+  absl::string_view kHlo = R"(
+    HloModule m, num_partitions=4
+
+    ENTRY e {
+      p = f32[128] parameter(0)
+      ROOT _ = f32[512] all-gather(p),
+        dimensions={0},
+        use_global_device_ids=true,
+        channel_id=1,
+        replica_groups={{0,1,2,3,4,5}}
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
+
+  HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
+      module->entry_computation()->root_instruction());
+  EXPECT_THAT(CommunicationType(/*partition_size=*/36, *instr,
+                                device_info().gpu_compute_capability()),
+              IsOkAndHolds(GPUCommunicationType::SINGLE_PARTITION));
+}
+
+TEST_F(CommunicationTypeTest, DetectsRailAlignedMultiPartition) {
+  // 128 devices across 2 partitions with partition_size=64 (rail-aligned: 64
+  // devices per partition).
+  absl::string_view kHlo = R"(
+    HloModule m, num_partitions=32
+
+    ENTRY e {
+      p = f32[128] parameter(0)
+      ROOT _ = f32[4096] all-gather(p),
+        dimensions={0},
+        use_global_device_ids=true,
+        channel_id=1,
+        replica_groups=[1,128]<=[128]
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
+
+  HloCollectiveInstruction* instr = Cast<HloCollectiveInstruction>(
+      module->entry_computation()->root_instruction());
+  EXPECT_THAT(CommunicationType(/*partition_size=*/64, *instr,
+                                device_info().gpu_compute_capability()),
+              IsOkAndHolds(GPUCommunicationType::MULTI_HOST_WORLD_LEVEL));
+}
+
 }  // namespace
 }  // namespace xla::gpu
@@ -110,6 +110,9 @@ std::string HloModuleConfig::compilation_cache_key() const {
     StrAppend(&key, "::device_memory_size=", device_memory_size());
   }
   StrAppend(&key, "::use_shardy_partitioner=", use_shardy_partitioner());
+  if (partition_size() != 0) {
+    StrAppend(&key, "::partition_size=", partition_size());
+  }
   return key;
 }
@@ -339,6 +342,7 @@ HloModuleConfigProto HloModuleConfig::ToProto() const {
   proto.set_fdo_profile(fdo_profile_);
   proto.set_device_memory_size(device_memory_size_);
   proto.set_use_shardy_partitioner(use_shardy_partitioner_);
+  proto.set_partition_size(partition_size_);
   *proto.mutable_sharding_config() = ShardingConfig::ToProto(sharding_config_);
   *proto.mutable_schedule_config() = ScheduleConfig::ToProto(schedule_config_);
   return proto;

@@ -418,6 +422,7 @@ HloModuleConfig::CreateFromProto(const HloModuleConfigProto& proto) {
   config->fdo_profile_ = proto.fdo_profile();
   config->device_memory_size_ = proto.device_memory_size();
   config->use_shardy_partitioner_ = proto.use_shardy_partitioner();
+  config->partition_size_ = proto.partition_size();
   config->sharding_config_ = ShardingConfig::FromProto(proto.sharding_config());
   config->schedule_config_ = ScheduleConfig::FromProto(proto.schedule_config());
   return std::move(config);
@@ -451,6 +451,12 @@ class HloModuleConfig {
     use_shardy_partitioner_ = use_shardy_partitioner;
   }
 
+  // Number of devices in a fast-interconnect domain.
+  int64_t partition_size() const { return partition_size_; }
+  void set_partition_size(int64_t partition_size) {
+    partition_size_ = partition_size;
+  }
+
   // Do channel IDs in this module carry semantic information.
   bool ChannelIdSensitive() const {
     // TODO(b/430952564): Base this on num_partitions / num_replicas instead

@@ -625,6 +631,9 @@ class HloModuleConfig {
 
   bool use_shardy_partitioner_ = false;
 
+  // Number of devices in a fast-interconnect domain.
+  int64_t partition_size_ = 0;
+
   // Sharding configuration, where sharding_config_.nodes[v] controls the
   // sharding of operation v.
   ShardingConfig sharding_config_;
third_party/xla/xla/xla.proto (vendored, 4 changes)
@@ -1606,7 +1606,7 @@ message ExecutionOptions {
 // Serialization of HloModuleConfig. See the C++ class definition for
 // descriptions of each field.
 // There are no guarantees of backwards or forwards compatibility.
-// Next id: 42.
+// Next id: 43.
 message HloModuleConfigProto {
   enum FusionConfigCollection {
     OFF = 0;  // Do not collect configuration.
@@ -1671,6 +1671,8 @@ message HloModuleConfigProto {
   bool use_shardy_partitioner = 34;
   ShardingConfigProto sharding_config = 38;
   ScheduleConfigProto schedule_config = 41;
+  // Number of devices in a fast-interconnect domain.
+  int64 partition_size = 42;
 }
 
 message HloModuleProtoWithConfig {
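Taken together, the HloModuleConfig and proto hunks above make partition_size survive the config -> proto -> config round trip that AOT compilation relies on. A minimal standalone sketch with hypothetical mirror types (not the real HloModuleConfigProto):

#include <cstdint>
#include <iostream>

// Hypothetical mirrors of HloModuleConfigProto and HloModuleConfig, just to
// show the round trip the hunks above wire up.
struct Proto { int64_t partition_size = 0; };
struct Config {
  int64_t partition_size = 0;
  Proto ToProto() const { return Proto{partition_size}; }
  static Config FromProto(const Proto& p) { return Config{p.partition_size}; }
};

int main() {
  Config config{/*partition_size=*/72};
  Config restored = Config::FromProto(config.ToProto());
  std::cout << restored.partition_size << "\n";  // 72
}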