Optimize PJRT_Executable_NumOutputs.

This function causes significant overhead when the C API is enabled because each call recomputes the output shapes via `GetOutputShapes`.

In this change, we rewrite the function to use [`PJRT_Executable.out_dimension_sizes`](1e5c568070/xla/pjrt/c/pjrt_c_api_wrapper_impl.h (L157)), which is **cached** on `PJRT_Executable`, so repeated calls to `Execute` become much cheaper. We must, however, make sure the output dimensions have been [`Populated`](1e5c568070/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc (L172)), so we do that in `...NumOutputs`. The new `EnsureExecutableOutputDimensionsPopulated` helper does this for both `PJRT_Executable_NumOutputs` and `PJRT_Executable_OutputDimensions`, eliminating code duplication.
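
For illustration, here is a minimal sketch (not part of this change) of a C-API caller querying the output count. The struct and function names match the C API exercised by the test below; `QueryNumOutputs` itself is a hypothetical helper. With this change, only the first call pays for populating the cached dimensions; later calls just read the cached size:

```cpp
#include <cassert>
#include <cstddef>

#include "xla/pjrt/c/pjrt_c_api.h"

// Hypothetical helper: returns the number of outputs of `executable` via the
// PJRT C API. `api` and `executable` are assumed to be valid (e.g. obtained
// through the usual plugin-loading and compilation path).
size_t QueryNumOutputs(const PJRT_Api* api, PJRT_Executable* executable) {
  PJRT_Executable_NumOutputs_Args args;
  args.struct_size = PJRT_Executable_NumOutputs_Args_STRUCT_SIZE;
  args.extension_start = nullptr;
  args.executable = executable;
  // After this change, the first call populates out_dimension_sizes once;
  // every later call only reads its size instead of recomputing output shapes.
  PJRT_Error* error = api->PJRT_Executable_NumOutputs(&args);
  assert(error == nullptr && "PJRT_Executable_NumOutputs failed");
  return args.num_outputs;
}
```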

**Alternatives**:
* *Calculate `NumOutputs` as-is, but cache the result on `PJRT_Executable` in a new `num_outputs` variable:* This works, but it makes `PJRT_Executable` bigger for no real benefit and forgoes the already-cached `out_dimension_sizes` win described above (a rough, hypothetical sketch of this option follows after this list).
* *Cache `GetOutputShapes` on `PJRT_Executable` and use it to compute all output-shape-dependent results (`PJRT_Executable_OutputDimensions`, `PJRT_Executable_NumOutputs`, `PJRT_Executable_OutputElementTypes`, etc.):* This makes `PJRT_Executable` **much** bigger and would require rewriting all of these functions, pulling (arguably) too much implementation detail into each of them.
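
For concreteness, a hypothetical sketch of what the first alternative would add to the wrapper struct (the struct and field names here are illustrative, not from this change):

```cpp
#include <cstddef>
#include <optional>

// Hypothetical sketch only: the rejected alternative would cache the scalar
// result directly on the wrapper struct, alongside the existing
// out_dimension_sizes / out_dimension_ran state, instead of reusing it.
struct ExecutableWrapperSketch {
  // ... existing members of PJRT_Executable would remain unchanged ...
  std::optional<size_t> cached_num_outputs;  // extra field plus its own "populated once" bookkeeping
};
```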

PiperOrigin-RevId: 814095618

@@ -286,6 +286,20 @@ TEST_F(PjrtCApiGpuExecutableTest, GetCompiledMemoryStats) {
   EXPECT_EQ(ref_stats.host_temp_size_in_bytes, stats.host_temp_size_in_bytes);
 }
 
+TEST_F(PjrtCApiGpuExecutableTest, GetNumOutputs) {
+  auto executable = PjrtCApiTestBase::GetExecutable(executable_.get(), api_);
+  PJRT_Executable_NumOutputs_Args num_outputs_args;
+  num_outputs_args.struct_size = PJRT_Executable_NumOutputs_Args_STRUCT_SIZE;
+  num_outputs_args.extension_start = nullptr;
+  num_outputs_args.executable = executable.get();
+  LogFatalIfPjrtError(api_->PJRT_Executable_NumOutputs(&num_outputs_args),
+                      api_);
+  TF_ASSERT_OK_AND_ASSIGN(auto ref_output_shapes,
+                          executable.get()->get()->GetOutputShapes());
+  EXPECT_EQ(num_outputs_args.num_outputs, ref_output_shapes.size());
+}
+
 TEST_F(PjrtCApiGpuExecutableTest, GetDeviceAssignment) {
   PJRT_LoadedExecutable_GetDeviceAssignment_Args args;
   args.struct_size = PJRT_LoadedExecutable_GetDeviceAssignment_Args_STRUCT_SIZE;

@@ -209,6 +209,16 @@ static absl::Status PopulateExecutableOutputDimensions(
   return absl::OkStatus();
 }
 
+static absl::Status EnsureExecutableOutputDimensionsPopulated(
+    PJRT_Executable* executable) {
+  absl::MutexLock lock(&executable->mutex);
+  if (!executable->out_dimension_ran) {
+    TF_RETURN_IF_ERROR(PopulateExecutableOutputDimensions(executable));
+    executable->out_dimension_ran = true;
+  }
+  return absl::OkStatus();
+}
+
 static absl::Status PopulateExecutableOutputMemoryKinds(
     PJRT_Executable* executable) {
   TF_ASSIGN_OR_RETURN(
@@ -1476,26 +1486,9 @@ PJRT_Error* PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args* args) {
   PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
       "PJRT_Executable_NumOutputs_Args",
       PJRT_Executable_NumOutputs_Args_STRUCT_SIZE, args->struct_size));
-  PJRT_ASSIGN_OR_RETURN(std::vector<xla::Shape> output_shapes,
-                        args->executable->get()->GetOutputShapes());
-  if (output_shapes.empty()) {
-    return new PJRT_Error{
-        xla::InvalidArgument("Can't get number of executable outputs, output "
-                             "shapes is empty for executable %s.",
-                             args->executable->get()->name())};
-  }
-  if (output_shapes.size() != 1) {
-    return new PJRT_Error{
-        xla::Unimplemented("MPMD execution not supported by PJRT C API (in "
-                           "function PJRT_Executable_NumOutputs).")};
-  }
-  const xla::Shape& shape = output_shapes[0];
-  if (shape.IsTuple()) {
-    args->num_outputs = shape.tuple_shapes().size();
-  } else {
-    // The output size is 1, as it is not a tuple.
-    args->num_outputs = 1;
-  }
+  PJRT_RETURN_IF_ERROR(
+      EnsureExecutableOutputDimensionsPopulated(args->executable));
+  args->num_outputs = args->executable->out_dimension_sizes.size();
   return nullptr;
 }
@@ -1637,14 +1630,8 @@ PJRT_Error* PJRT_Executable_OutputDimensions(
       "PJRT_Executable_OutputDimensions_Args",
       PJRT_Executable_OutputDimensions_Args_STRUCT_SIZE, args->struct_size));
-  {
-    absl::MutexLock lock(args->executable->mutex);
-    if (!args->executable->out_dimension_ran) {
-      PJRT_RETURN_IF_ERROR(
-          PopulateExecutableOutputDimensions(args->executable));
-      args->executable->out_dimension_ran = true;
-    }
-  }
+  PJRT_RETURN_IF_ERROR(
+      EnsureExecutableOutputDimensionsPopulated(args->executable));
   args->num_outputs = args->executable->out_dimension_sizes.size();
   args->dim_sizes = args->executable->out_dimension_sizes.data();