From 2a8dcaf97ee71696dd08ed6e23c99dcb2c993e3e Mon Sep 17 00:00:00 2001 From: Zac Mustin Date: Wed, 1 Oct 2025 23:58:08 -0700 Subject: [PATCH] Optimize `PJRT_Executable_NumOutputs`. This function adds significant overhead when the C API is enabled, because every call invokes `GetOutputShapes`. In this change, we rewrite the function to make use of [`PJRT_Executable.out_dimension_sizes`](https://github.com/openxla/xla/blob/1e5c56807066808cd487159a63691cee7e406df6/xla/pjrt/c/pjrt_c_api_wrapper_impl.h#L157), which is **cached** on `PJRT_Executable`. Therefore, repeated calls to `Execute` become much cheaper. We must, however, make sure `out_dimension_sizes` has been [populated](https://github.com/openxla/xla/blob/1e5c56807066808cd487159a63691cee7e406df6/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc#L172) first, so we do that in `PJRT_Executable_NumOutputs`. The new `EnsureExecutableOutputDimensionsPopulated` function lets `PJRT_Executable_NumOutputs` and `PJRT_Executable_OutputDimensions` share the lock-and-populate logic instead of duplicating it. **Alternatives**: * *Calculate `NumOutputs` as-is, but cache the result on `PJRT_Executable` in a new `num_outputs` variable:* This is fine but makes `PJRT_Executable` bigger for no real reason, and we don't benefit from the already-cached `out_dimension_sizes` win described above. * *Cache `GetOutputShapes` on `PJRT_Executable`, and use that to calculate all output-shape-dependent functions (`PJRT_Executable_OutputDimensions`, `PJRT_Executable_NumOutputs`, `PJRT_Executable_OutputElementTypes`, etc.):* Makes `PJRT_Executable` **much** bigger and would require rewriting all these functions to expose (arguably) too many implementation details. 
PiperOrigin-RevId: 814095618 --- .../xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc | 14 ++++++ .../xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 43 +++++++------------ 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc index 7344d292c48..f032c3eb87c 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -286,6 +286,20 @@ TEST_F(PjrtCApiGpuExecutableTest, GetCompiledMemoryStats) { EXPECT_EQ(ref_stats.host_temp_size_in_bytes, stats.host_temp_size_in_bytes); } +TEST_F(PjrtCApiGpuExecutableTest, GetNumOutputs) { + auto executable = PjrtCApiTestBase::GetExecutable(executable_.get(), api_); + PJRT_Executable_NumOutputs_Args num_outputs_args; + num_outputs_args.struct_size = PJRT_Executable_NumOutputs_Args_STRUCT_SIZE; + num_outputs_args.extension_start = nullptr; + num_outputs_args.executable = executable.get(); + LogFatalIfPjrtError(api_->PJRT_Executable_NumOutputs(&num_outputs_args), + api_); + + TF_ASSERT_OK_AND_ASSIGN(auto ref_output_shapes, + executable.get()->get()->GetOutputShapes()); + EXPECT_EQ(num_outputs_args.num_outputs, ref_output_shapes.size()); +} + TEST_F(PjrtCApiGpuExecutableTest, GetDeviceAssignment) { PJRT_LoadedExecutable_GetDeviceAssignment_Args args; args.struct_size = PJRT_LoadedExecutable_GetDeviceAssignment_Args_STRUCT_SIZE; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 61c718e85c7..d582a94d9b1 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -209,6 +209,16 @@ static absl::Status PopulateExecutableOutputDimensions( return absl::OkStatus(); } +static absl::Status EnsureExecutableOutputDimensionsPopulated( + PJRT_Executable* executable) { + absl::MutexLock lock(&executable->mutex); + if (!executable->out_dimension_ran) 
{ + TF_RETURN_IF_ERROR(PopulateExecutableOutputDimensions(executable)); + executable->out_dimension_ran = true; + } + return absl::OkStatus(); +} + static absl::Status PopulateExecutableOutputMemoryKinds( PJRT_Executable* executable) { TF_ASSIGN_OR_RETURN( @@ -1476,26 +1486,9 @@ PJRT_Error* PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args* args) { PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( "PJRT_Executable_NumOutputs_Args", PJRT_Executable_NumOutputs_Args_STRUCT_SIZE, args->struct_size)); - PJRT_ASSIGN_OR_RETURN(std::vector output_shapes, - args->executable->get()->GetOutputShapes()); - if (output_shapes.empty()) { - return new PJRT_Error{ - xla::InvalidArgument("Can't get number of executable outputs, output " - "shapes is empty for executable %s.", - args->executable->get()->name())}; - } - if (output_shapes.size() != 1) { - return new PJRT_Error{ - xla::Unimplemented("MPMD execution not supported by PJRT C API (in " - "function PJRT_Executable_NumOutputs).")}; - } - const xla::Shape& shape = output_shapes[0]; - if (shape.IsTuple()) { - args->num_outputs = shape.tuple_shapes().size(); - } else { - // The output size is 1, as it is not a tuple. 
- args->num_outputs = 1; - } + PJRT_RETURN_IF_ERROR( + EnsureExecutableOutputDimensionsPopulated(args->executable)); + args->num_outputs = args->executable->out_dimension_sizes.size(); return nullptr; } @@ -1637,14 +1630,8 @@ PJRT_Error* PJRT_Executable_OutputDimensions( "PJRT_Executable_OutputDimensions_Args", PJRT_Executable_OutputDimensions_Args_STRUCT_SIZE, args->struct_size)); - { - absl::MutexLock lock(args->executable->mutex); - if (!args->executable->out_dimension_ran) { - PJRT_RETURN_IF_ERROR( - PopulateExecutableOutputDimensions(args->executable)); - args->executable->out_dimension_ran = true; - } - } + PJRT_RETURN_IF_ERROR( + EnsureExecutableOutputDimensionsPopulated(args->executable)); args->num_outputs = args->executable->out_dimension_sizes.size(); args->dim_sizes = args->executable->out_dimension_sizes.data();