Adds topology utility functions

PiperOrigin-RevId: 821858216
This commit is contained in:
Haibo Huang 2025-10-20 16:47:55 -07:00 committed by TensorFlower Gardener
parent ce507e7993
commit d2e02ce8d9

View File

@ -215,6 +215,22 @@ class PjRtTopologyDescription {
}
};
// Returns true if it's TPU topology.
inline bool IsTpuTopology(const PjRtTopologyDescription& topology_description) {
return topology_description.platform_id() == xla::TpuId();
}
// Returns true if it's GPU topology.
inline bool IsGpuTopology(const PjRtTopologyDescription& topology_description) {
return topology_description.platform_id() == xla::CudaId() ||
topology_description.platform_id() == xla::RocmId();
}
// Returns true if it's CPU topology.
inline bool IsCpuTopology(const PjRtTopologyDescription& topology_description) {
return topology_description.platform_id() == xla::CpuId();
}
// Abstract interface that all registered compilers must implement.
class PjRtCompiler {
public: