Add PJRT c sandwich benchmarks to nanort benchmarks.

PiperOrigin-RevId: 823259666
This commit is contained in:
Zac Mustin 2025-10-23 17:56:18 -07:00 committed by TensorFlower Gardener
parent 4ed3ee15e7
commit 5893a54e81
2 changed files with 18 additions and 6 deletions

View File

@ -61,6 +61,9 @@ xla_cc_test(
"//xla/hlo/parser:hlo_parser", "//xla/hlo/parser:hlo_parser",
"//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_executable",
"//xla/pjrt/c_api_client:pjrt_c_api_client",
"//xla/pjrt/plugin:plugin_names",
"//xla/pjrt/plugin/xla_cpu:cpu_static_registration",
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
"//xla/runtime:device_id", "//xla/runtime:device_id",
"//xla/service:computation_placer_hdr", "//xla/service:computation_placer_hdr",

View File

@ -42,8 +42,10 @@ limitations under the License.
#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/parser/hlo_parser.h"
#include "xla/literal.h" #include "xla/literal.h"
#include "xla/literal_util.h" #include "xla/literal_util.h"
#include "xla/pjrt/c_api_client/pjrt_c_api_client.h"
#include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/plugin/plugin_names.h"
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
#include "xla/runtime/device_id.h" #include "xla/runtime/device_id.h"
#include "xla/service/computation_placer.h" #include "xla/service/computation_placer.h"
@ -577,8 +579,11 @@ BENCHMARK_CAPTURE(BM_NanoRtFibonacci, no_thread_pool, std::nullopt);
BENCHMARK_CAPTURE(BM_NanoRtFibonacci, thread_pool, BENCHMARK_CAPTURE(BM_NanoRtFibonacci, thread_pool,
std::make_optional<Eigen::ThreadPool>(2)); std::make_optional<Eigen::ThreadPool>(2));
static void BM_PjRtAddScalars(benchmark::State& state) { static void BM_PjRtAddScalars(benchmark::State& state,
auto client = GetXlaPjrtCpuClient(/*options=*/{}); bool use_c_api_sandwich) {
auto client = use_c_api_sandwich
? xla::GetCApiClient(kCpuPjrtName, /*create_options=*/{})
: GetXlaPjrtCpuClient(/*options=*/{});
PjRtDevice* device = (*client)->devices().front(); PjRtDevice* device = (*client)->devices().front();
PjRtMemorySpace* memory_space = *device->default_memory_space(); PjRtMemorySpace* memory_space = *device->default_memory_space();
@ -609,10 +614,13 @@ static void BM_PjRtAddScalars(benchmark::State& state) {
} }
} }
BENCHMARK(BM_PjRtAddScalars); BENCHMARK_CAPTURE(BM_PjRtAddScalars, Direct, false);
BENCHMARK_CAPTURE(BM_PjRtAddScalars, CSandwich, true);
static void BM_PjRtFibonacci(benchmark::State& state) { static void BM_PjRtFibonacci(benchmark::State& state, bool use_c_api_sandwich) {
auto client = GetXlaPjrtCpuClient(/*options=*/{}); auto client = use_c_api_sandwich
? xla::GetCApiClient(kCpuPjrtName, /*create_options=*/{})
: GetXlaPjrtCpuClient(/*options=*/{});
PjRtDevice* device = (*client)->devices().front(); PjRtDevice* device = (*client)->devices().front();
PjRtMemorySpace* memory_space = *device->default_memory_space(); PjRtMemorySpace* memory_space = *device->default_memory_space();
@ -643,7 +651,8 @@ static void BM_PjRtFibonacci(benchmark::State& state) {
} }
} }
BENCHMARK(BM_PjRtFibonacci); BENCHMARK_CAPTURE(BM_PjRtFibonacci, Direct, false);
BENCHMARK_CAPTURE(BM_PjRtFibonacci, CSandwich, true);
} // namespace } // namespace
} // namespace xla::cpu } // namespace xla::cpu