mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Optimize PJRT_Executable_NumOutputs.
This function causes significant overhead when enabling the C API because of its repeated calls to `GetOutputShapes`. In this change, we rewrite the function to make use of [`PJRT_Executable.out_dimensions_sizes`](1e5c568070/xla/pjrt/c/pjrt_c_api_wrapper_impl.h (L157)), which is **cached** on `PJRT_Executable`. Therefore, repeated calls to `Execute` become much cheaper. We must, however, make sure `out_dimensions` has been [`Populated`](1e5c568070/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc (L172)), so we do that in `...NumOutputs`. The new `EnsureExecutableOutputDimensionsPopulated` function eliminates code duplication. **Alternatives**: * *Calculate `NumOutputs` as-is, but cache the result on `PJRT_Executable` in a new `num_outputs` variable:* This is fine but makes `PJRT_Executable` bigger for no real reason, and we don't benefit from the already-cached `out_dimensions` win described above. * *Cache `GetOutputShapes` on PJRT_Executable, use that to calculate all output-shape-dependent functions (`PJRT_Executable_OutputDimensions`, `PJRT_Executable_NumOutputs`, `PJRT_Executable_OutputElementTypes`, etc):* Makes `PJRT_Executable` **much** bigger and would require re-writing all these functions to have (arguably) "too much" implementation details. PiperOrigin-RevId: 814095618
This commit is contained in:
parent
51652cf914
commit
2a8dcaf97e
|
|
@ -286,6 +286,20 @@ TEST_F(PjrtCApiGpuExecutableTest, GetCompiledMemoryStats) {
|
|||
EXPECT_EQ(ref_stats.host_temp_size_in_bytes, stats.host_temp_size_in_bytes);
|
||||
}
|
||||
|
||||
TEST_F(PjrtCApiGpuExecutableTest, GetNumOutputs) {
  // Wrap the loaded executable in a C-API PJRT_Executable handle.
  auto executable = PjrtCApiTestBase::GetExecutable(executable_.get(), api_);

  // Query the output count through the C API.
  PJRT_Executable_NumOutputs_Args args;
  args.struct_size = PJRT_Executable_NumOutputs_Args_STRUCT_SIZE;
  args.extension_start = nullptr;
  args.executable = executable.get();
  LogFatalIfPjrtError(api_->PJRT_Executable_NumOutputs(&args), api_);

  // The count reported by the C API must agree with the number of output
  // shapes reported by the underlying C++ executable.
  TF_ASSERT_OK_AND_ASSIGN(auto ref_output_shapes,
                          executable.get()->get()->GetOutputShapes());
  EXPECT_EQ(args.num_outputs, ref_output_shapes.size());
}
|
||||
|
||||
TEST_F(PjrtCApiGpuExecutableTest, GetDeviceAssignment) {
|
||||
PJRT_LoadedExecutable_GetDeviceAssignment_Args args;
|
||||
args.struct_size = PJRT_LoadedExecutable_GetDeviceAssignment_Args_STRUCT_SIZE;
|
||||
|
|
|
|||
|
|
@ -209,6 +209,16 @@ static absl::Status PopulateExecutableOutputDimensions(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Populates the executable's cached output-dimension state exactly once,
// guarding it with the executable's mutex. After the first successful call,
// subsequent calls are cheap no-ops, so callers may invoke this on every
// API entry point that needs the cached dimensions.
static absl::Status EnsureExecutableOutputDimensionsPopulated(
    PJRT_Executable* executable) {
  absl::MutexLock lock(&executable->mutex);
  if (executable->out_dimension_ran) {
    // Already populated; nothing to do.
    return absl::OkStatus();
  }
  TF_RETURN_IF_ERROR(PopulateExecutableOutputDimensions(executable));
  executable->out_dimension_ran = true;
  return absl::OkStatus();
}
|
||||
|
||||
static absl::Status PopulateExecutableOutputMemoryKinds(
|
||||
PJRT_Executable* executable) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
|
|
@ -1476,26 +1486,9 @@ PJRT_Error* PJRT_Executable_NumOutputs(PJRT_Executable_NumOutputs_Args* args) {
|
|||
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
|
||||
"PJRT_Executable_NumOutputs_Args",
|
||||
PJRT_Executable_NumOutputs_Args_STRUCT_SIZE, args->struct_size));
|
||||
PJRT_ASSIGN_OR_RETURN(std::vector<xla::Shape> output_shapes,
|
||||
args->executable->get()->GetOutputShapes());
|
||||
if (output_shapes.empty()) {
|
||||
return new PJRT_Error{
|
||||
xla::InvalidArgument("Can't get number of executable outputs, output "
|
||||
"shapes is empty for executable %s.",
|
||||
args->executable->get()->name())};
|
||||
}
|
||||
if (output_shapes.size() != 1) {
|
||||
return new PJRT_Error{
|
||||
xla::Unimplemented("MPMD execution not supported by PJRT C API (in "
|
||||
"function PJRT_Executable_NumOutputs).")};
|
||||
}
|
||||
const xla::Shape& shape = output_shapes[0];
|
||||
if (shape.IsTuple()) {
|
||||
args->num_outputs = shape.tuple_shapes().size();
|
||||
} else {
|
||||
// The output size is 1, as it is not a tuple.
|
||||
args->num_outputs = 1;
|
||||
}
|
||||
PJRT_RETURN_IF_ERROR(
|
||||
EnsureExecutableOutputDimensionsPopulated(args->executable));
|
||||
args->num_outputs = args->executable->out_dimension_sizes.size();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
@ -1637,14 +1630,8 @@ PJRT_Error* PJRT_Executable_OutputDimensions(
|
|||
"PJRT_Executable_OutputDimensions_Args",
|
||||
PJRT_Executable_OutputDimensions_Args_STRUCT_SIZE, args->struct_size));
|
||||
|
||||
{
|
||||
absl::MutexLock lock(args->executable->mutex);
|
||||
if (!args->executable->out_dimension_ran) {
|
||||
PJRT_RETURN_IF_ERROR(
|
||||
PopulateExecutableOutputDimensions(args->executable));
|
||||
args->executable->out_dimension_ran = true;
|
||||
}
|
||||
}
|
||||
PJRT_RETURN_IF_ERROR(
|
||||
EnsureExecutableOutputDimensionsPopulated(args->executable));
|
||||
|
||||
args->num_outputs = args->executable->out_dimension_sizes.size();
|
||||
args->dim_sizes = args->executable->out_dimension_sizes.data();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user