mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Adds topology utility functions
PiperOrigin-RevId: 821858216
This commit is contained in:
parent
ce507e7993
commit
d2e02ce8d9
16
third_party/xla/xla/pjrt/pjrt_compiler.h
vendored
16
third_party/xla/xla/pjrt/pjrt_compiler.h
vendored
|
|
@ -215,6 +215,22 @@ class PjRtTopologyDescription {
|
|||
}
|
||||
};
|
||||
|
||||
// Returns true if it's TPU topology.
|
||||
inline bool IsTpuTopology(const PjRtTopologyDescription& topology_description) {
|
||||
return topology_description.platform_id() == xla::TpuId();
|
||||
}
|
||||
|
||||
// Returns true if it's GPU topology.
|
||||
inline bool IsGpuTopology(const PjRtTopologyDescription& topology_description) {
|
||||
return topology_description.platform_id() == xla::CudaId() ||
|
||||
topology_description.platform_id() == xla::RocmId();
|
||||
}
|
||||
|
||||
// Returns true if it's CPU topology.
|
||||
inline bool IsCpuTopology(const PjRtTopologyDescription& topology_description) {
|
||||
return topology_description.platform_id() == xla::CpuId();
|
||||
}
|
||||
|
||||
// Abstract interface that all registered compilers must implement.
|
||||
class PjRtCompiler {
|
||||
public:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user