mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Add PJRT C API sandwich benchmarks to the NanoRt benchmarks.
PiperOrigin-RevId: 823259666
This commit is contained in:
parent
4ed3ee15e7
commit
5893a54e81
|
|
@ -61,6 +61,9 @@ xla_cc_test(
|
|||
"//xla/hlo/parser:hlo_parser",
|
||||
"//xla/pjrt:pjrt_client",
|
||||
"//xla/pjrt:pjrt_executable",
|
||||
"//xla/pjrt/c_api_client:pjrt_c_api_client",
|
||||
"//xla/pjrt/plugin:plugin_names",
|
||||
"//xla/pjrt/plugin/xla_cpu:cpu_static_registration",
|
||||
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
|
||||
"//xla/runtime:device_id",
|
||||
"//xla/service:computation_placer_hdr",
|
||||
|
|
|
|||
|
|
@ -42,8 +42,10 @@ limitations under the License.
|
|||
#include "xla/hlo/parser/hlo_parser.h"
|
||||
#include "xla/literal.h"
|
||||
#include "xla/literal_util.h"
|
||||
#include "xla/pjrt/c_api_client/pjrt_c_api_client.h"
|
||||
#include "xla/pjrt/pjrt_client.h"
|
||||
#include "xla/pjrt/pjrt_executable.h"
|
||||
#include "xla/pjrt/plugin/plugin_names.h"
|
||||
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
|
||||
#include "xla/runtime/device_id.h"
|
||||
#include "xla/service/computation_placer.h"
|
||||
|
|
@ -577,8 +579,11 @@ BENCHMARK_CAPTURE(BM_NanoRtFibonacci, no_thread_pool, std::nullopt);
|
|||
BENCHMARK_CAPTURE(BM_NanoRtFibonacci, thread_pool,
|
||||
std::make_optional<Eigen::ThreadPool>(2));
|
||||
|
||||
static void BM_PjRtAddScalars(benchmark::State& state) {
|
||||
auto client = GetXlaPjrtCpuClient(/*options=*/{});
|
||||
static void BM_PjRtAddScalars(benchmark::State& state,
|
||||
bool use_c_api_sandwich) {
|
||||
auto client = use_c_api_sandwich
|
||||
? xla::GetCApiClient(kCpuPjrtName, /*create_options=*/{})
|
||||
: GetXlaPjrtCpuClient(/*options=*/{});
|
||||
PjRtDevice* device = (*client)->devices().front();
|
||||
PjRtMemorySpace* memory_space = *device->default_memory_space();
|
||||
|
||||
|
|
@ -609,10 +614,13 @@ static void BM_PjRtAddScalars(benchmark::State& state) {
|
|||
}
|
||||
}
|
||||
|
||||
BENCHMARK(BM_PjRtAddScalars);
|
||||
BENCHMARK_CAPTURE(BM_PjRtAddScalars, Direct, false);
|
||||
BENCHMARK_CAPTURE(BM_PjRtAddScalars, CSandwich, true);
|
||||
|
||||
static void BM_PjRtFibonacci(benchmark::State& state) {
|
||||
auto client = GetXlaPjrtCpuClient(/*options=*/{});
|
||||
static void BM_PjRtFibonacci(benchmark::State& state, bool use_c_api_sandwich) {
|
||||
auto client = use_c_api_sandwich
|
||||
? xla::GetCApiClient(kCpuPjrtName, /*create_options=*/{})
|
||||
: GetXlaPjrtCpuClient(/*options=*/{});
|
||||
PjRtDevice* device = (*client)->devices().front();
|
||||
PjRtMemorySpace* memory_space = *device->default_memory_space();
|
||||
|
||||
|
|
@ -643,7 +651,8 @@ static void BM_PjRtFibonacci(benchmark::State& state) {
|
|||
}
|
||||
}
|
||||
|
||||
BENCHMARK(BM_PjRtFibonacci);
|
||||
BENCHMARK_CAPTURE(BM_PjRtFibonacci, Direct, false);
|
||||
BENCHMARK_CAPTURE(BM_PjRtFibonacci, CSandwich, true);
|
||||
|
||||
} // namespace
|
||||
} // namespace xla::cpu
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user