mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Unify topology in PjRtTopologyDescription
The topology on pjrt layer can be seen as: (process, chip, logical device) or (process, chip, core) For cpu, it is (1, num device, 1) For gpu, it is (num host, gpu per host, 1) PiperOrigin-RevId: 826581627
This commit is contained in:
parent
e0f6a6c7f3
commit
8572aaa4e9
1
third_party/xla/xla/pjrt/BUILD
vendored
1
third_party/xla/xla/pjrt/BUILD
vendored
|
|
@ -479,6 +479,7 @@ cc_library(
|
|||
"//xla/hlo/builder:xla_computation",
|
||||
"//xla/pjrt/proto:pjrt_partial_program_proto_cc",
|
||||
"//xla/pjrt/proto:topology_description_proto_cc",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/log",
|
||||
|
|
|
|||
|
|
@ -86,22 +86,18 @@ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
|
|||
return gpu_topology_->number_of_hosts();
|
||||
}
|
||||
|
||||
absl::StatusOr<int> CoreCountOfDefaultType() const override {
|
||||
return gpu_topology_->number_of_devices();
|
||||
}
|
||||
|
||||
absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
|
||||
return gpu_topology_->number_of_devices();
|
||||
}
|
||||
|
||||
absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
|
||||
return gpu_topology_->number_of_devices();
|
||||
absl::StatusOr<int> ChipsPerProcess() const override {
|
||||
return gpu_topology_->num_devices_per_host();
|
||||
}
|
||||
|
||||
absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
|
||||
return 1;
|
||||
}
|
||||
|
||||
absl::StatusOr<int> LogicalDeviceCountOfDefaultTypePerChip() const override {
|
||||
return 1;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::pair<PjRtDeviceDimensions, int32_t>>
|
||||
LogicalDeviceOfDefaultTypeForId(
|
||||
xla::PjRtGlobalDeviceId device_id) const override;
|
||||
|
|
|
|||
37
third_party/xla/xla/pjrt/pjrt_compiler.h
vendored
37
third_party/xla/xla/pjrt/pjrt_compiler.h
vendored
|
|
@ -37,6 +37,7 @@ limitations under the License.
|
|||
#include "xla/pjrt/pjrt_executable.h"
|
||||
#include "xla/pjrt/proto/pjrt_partial_program.pb.h"
|
||||
#include "xla/pjrt/proto/topology_description.pb.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "tsl/platform/fingerprint.h"
|
||||
|
||||
namespace xla {
|
||||
|
|
@ -121,20 +122,41 @@ class PjRtTopologyDescription {
|
|||
return absl::UnimplementedError("ProcessCount is unsupported.");
|
||||
}
|
||||
|
||||
// Returns the number of chips per process.
|
||||
virtual absl::StatusOr<int> ChipsPerProcess() const {
|
||||
return absl::UnimplementedError("ChipsPerProcess is unsupported.");
|
||||
}
|
||||
|
||||
// Returns the number of chips.
|
||||
virtual absl::StatusOr<int> ChipCount() const {
|
||||
return absl::UnimplementedError("ChipCount is unsupported.");
|
||||
TF_ASSIGN_OR_RETURN(int process_count, ProcessCount());
|
||||
TF_ASSIGN_OR_RETURN(int chips_per_process, ChipsPerProcess());
|
||||
return process_count * chips_per_process;
|
||||
}
|
||||
|
||||
// Returns the total number of cores of the default type.
|
||||
virtual absl::StatusOr<int> CoreCountOfDefaultType() const {
|
||||
return absl::UnimplementedError("CoreCountOfDefaultType is unsupported.");
|
||||
TF_ASSIGN_OR_RETURN(int process_count, ProcessCount());
|
||||
TF_ASSIGN_OR_RETURN(int cores_per_process,
|
||||
CoreCountOfDefaultTypePerProcess());
|
||||
return process_count * cores_per_process;
|
||||
}
|
||||
|
||||
// As above, but returns the number of logical devices per host.
|
||||
virtual absl::StatusOr<int> LogicalDeviceCountOfDefaultTypePerProcess()
|
||||
const {
|
||||
TF_ASSIGN_OR_RETURN(int logical_devices_per_chip,
|
||||
LogicalDeviceCountOfDefaultTypePerChip());
|
||||
TF_ASSIGN_OR_RETURN(int chips_per_process, ChipsPerProcess());
|
||||
return chips_per_process * logical_devices_per_chip;
|
||||
}
|
||||
|
||||
// Returns the total number of logical devices of the default type.
|
||||
virtual absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const {
|
||||
return absl::UnimplementedError(
|
||||
"LogicalDeviceCountOfDefaultType is unsupported.");
|
||||
TF_ASSIGN_OR_RETURN(int process_count, ProcessCount());
|
||||
TF_ASSIGN_OR_RETURN(int logical_devices_per_process,
|
||||
LogicalDeviceCountOfDefaultTypePerProcess());
|
||||
return process_count * logical_devices_per_process;
|
||||
}
|
||||
|
||||
// Returns the number of logical devices of the default type per chip.
|
||||
|
|
@ -145,8 +167,9 @@ class PjRtTopologyDescription {
|
|||
|
||||
// Returns the number of cores of the default type per process.
|
||||
virtual absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const {
|
||||
return absl::UnimplementedError(
|
||||
"CoreCountOfDefaultTypePerProcess is unsupported.");
|
||||
TF_ASSIGN_OR_RETURN(int cores_per_chip, CoreCountOfDefaultTypePerChip());
|
||||
TF_ASSIGN_OR_RETURN(int chips_per_process, ChipsPerProcess());
|
||||
return cores_per_chip * chips_per_process;
|
||||
}
|
||||
|
||||
// Returns the number of cores per chip for the default type.
|
||||
|
|
@ -178,6 +201,8 @@ class PjRtTopologyDescription {
|
|||
}
|
||||
|
||||
// Returns the total bounds of all chips in the topology.
|
||||
// Usually this equals to the product of `ChipsPerHostBounds()` and
|
||||
// `HostBounds()`.
|
||||
virtual absl::StatusOr<PjRtDeviceDimensions> ChipBounds() const {
|
||||
return absl::UnimplementedError("ChipBounds is unsupported.");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||
#include "absl/types/span.h"
|
||||
#include "xla/layout.h"
|
||||
#include "xla/pjrt/pjrt_client.h"
|
||||
#include "xla/pjrt/pjrt_common.h"
|
||||
#include "xla/pjrt/pjrt_compiler.h"
|
||||
#include "xla/pjrt/pjrt_device_description.h"
|
||||
#include "xla/pjrt/pjrt_device_dimensions.h"
|
||||
|
|
@ -80,16 +81,12 @@ class CpuTopologyDescription : public PjRtTopologyDescription {
|
|||
// correctly report process count.
|
||||
absl::StatusOr<int> ProcessCount() const override { return 1; }
|
||||
|
||||
absl::StatusOr<int> CoreCountOfDefaultType() const override {
|
||||
absl::StatusOr<int> ChipsPerProcess() const override {
|
||||
return cpu_topology_.number_of_devices();
|
||||
}
|
||||
|
||||
absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
|
||||
return cpu_topology_.number_of_devices();
|
||||
}
|
||||
|
||||
absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
|
||||
return cpu_topology_.number_of_devices();
|
||||
absl::StatusOr<int> LogicalDeviceCountOfDefaultTypePerChip() const override {
|
||||
return 1;
|
||||
}
|
||||
|
||||
absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user