Unify topology in PjRtTopologyDescription

The topology on pjrt layer can be seen as:

(process, chip, logical device) or (process, chip, core)

For cpu, it is (1, num device, 1)

For gpu, it is (num host, gpu per host, 1)

PiperOrigin-RevId: 826581627
This commit is contained in:
Haibo Huang 2025-10-31 12:13:20 -07:00 committed by TensorFlower Gardener
parent e0f6a6c7f3
commit 8572aaa4e9
4 changed files with 42 additions and 23 deletions

View File

@ -479,6 +479,7 @@ cc_library(
"//xla/hlo/builder:xla_computation",
"//xla/pjrt/proto:pjrt_partial_program_proto_cc",
"//xla/pjrt/proto:topology_description_proto_cc",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",

View File

@ -86,22 +86,18 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
return gpu_topology_->number_of_hosts();
}
absl::StatusOr<int> CoreCountOfDefaultType() const override {
return gpu_topology_->number_of_devices();
}
absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
return gpu_topology_->number_of_devices();
}
absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
return gpu_topology_->number_of_devices();
absl::StatusOr<int> ChipsPerProcess() const override {
return gpu_topology_->num_devices_per_host();
}
absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
return 1;
}
absl::StatusOr<int> LogicalDeviceCountOfDefaultTypePerChip() const override {
return 1;
}
absl::StatusOr<std::pair<PjRtDeviceDimensions, int32_t>>
LogicalDeviceOfDefaultTypeForId(
xla::PjRtGlobalDeviceId device_id) const override;

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/proto/pjrt_partial_program.pb.h"
#include "xla/pjrt/proto/topology_description.pb.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/platform/fingerprint.h"
namespace xla {
@ -121,20 +122,41 @@ class PjRtTopologyDescription {
return absl::UnimplementedError("ProcessCount is unsupported.");
}
// Returns the number of chips per process.
virtual absl::StatusOr<int> ChipsPerProcess() const {
return absl::UnimplementedError("ChipsPerProcess is unsupported.");
}
// Returns the number of chips.
virtual absl::StatusOr<int> ChipCount() const {
return absl::UnimplementedError("ChipCount is unsupported.");
TF_ASSIGN_OR_RETURN(int process_count, ProcessCount());
TF_ASSIGN_OR_RETURN(int chips_per_process, ChipsPerProcess());
return process_count * chips_per_process;
}
// Returns the total number of cores of the default type.
virtual absl::StatusOr<int> CoreCountOfDefaultType() const {
return absl::UnimplementedError("CoreCountOfDefaultType is unsupported.");
TF_ASSIGN_OR_RETURN(int process_count, ProcessCount());
TF_ASSIGN_OR_RETURN(int cores_per_process,
CoreCountOfDefaultTypePerProcess());
return process_count * cores_per_process;
}
// As above, but returns the number of logical devices per host.
virtual absl::StatusOr<int> LogicalDeviceCountOfDefaultTypePerProcess()
const {
TF_ASSIGN_OR_RETURN(int logical_devices_per_chip,
LogicalDeviceCountOfDefaultTypePerChip());
TF_ASSIGN_OR_RETURN(int chips_per_process, ChipsPerProcess());
return chips_per_process * logical_devices_per_chip;
}
// Returns the total number of logical devices of the default type.
virtual absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const {
return absl::UnimplementedError(
"LogicalDeviceCountOfDefaultType is unsupported.");
TF_ASSIGN_OR_RETURN(int process_count, ProcessCount());
TF_ASSIGN_OR_RETURN(int logical_devices_per_process,
LogicalDeviceCountOfDefaultTypePerProcess());
return process_count * logical_devices_per_process;
}
// Returns the number of logical devices of the default type per chip.
@ -145,8 +167,9 @@ class PjRtTopologyDescription {
// Returns the number of cores of the default type per process.
virtual absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const {
return absl::UnimplementedError(
"CoreCountOfDefaultTypePerProcess is unsupported.");
TF_ASSIGN_OR_RETURN(int cores_per_chip, CoreCountOfDefaultTypePerChip());
TF_ASSIGN_OR_RETURN(int chips_per_process, ChipsPerProcess());
return cores_per_chip * chips_per_process;
}
// Returns the number of cores per chip for the default type.
@ -178,6 +201,8 @@ class PjRtTopologyDescription {
}
// Returns the total bounds of all chips in the topology.
// Usually this equals to the product of `ChipsPerHostBounds()` and
// `HostBounds()`.
virtual absl::StatusOr<PjRtDeviceDimensions> ChipBounds() const {
return absl::UnimplementedError("ChipBounds is unsupported.");
}

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/layout.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_common.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_device_description.h"
#include "xla/pjrt/pjrt_device_dimensions.h"
@ -80,16 +81,12 @@ class CpuTopologyDescription : public PjRtTopologyDescription {
// correctly report process count.
absl::StatusOr<int> ProcessCount() const override { return 1; }
absl::StatusOr<int> CoreCountOfDefaultType() const override {
absl::StatusOr<int> ChipsPerProcess() const override {
return cpu_topology_.number_of_devices();
}
absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
return cpu_topology_.number_of_devices();
}
absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
return cpu_topology_.number_of_devices();
absl::StatusOr<int> LogicalDeviceCountOfDefaultTypePerChip() const override {
return 1;
}
absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {