[PJRT C] Implement Executable::GetOutputLayouts() in the PJRT Layouts extension

This change implements native support for `xla::Executable::GetOutputLayouts()` in the PJRT C API when the PJRT Layouts extension is available. This support does not fetch the optimized HLO, which makes the method more reliable and faster.

Plugins that implement the Layouts extension v2 are strongly encouraged to upgrade to v3 to avoid an incompatibility.

PiperOrigin-RevId: 821834116
This commit is contained in:
Hyeontaek Lim 2025-10-20 15:38:24 -07:00 committed by TensorFlower Gardener
parent a40f3bdebd
commit 67e5eafb24
5 changed files with 129 additions and 4 deletions

View File

@ -26,8 +26,9 @@ extern "C" {
#endif
// This extension provides capabilities around custom on-device memory layouts
// for PJRT_Buffers. The extension is both optional and experimental, meaning
// ABI-breaking and other incompatible changes may be introduced at any time.
// for PJRT_Buffers and PJRT_Executables. The extension is both optional and
// experimental, meaning ABI-breaking and other incompatible changes may be
// introduced at any time.
//
// If this extension is provided, JAX and possibly other frameworks will assume
// that the compiler MLIR input can contain "mhlo.layout_mode" attributes on
@ -36,7 +37,7 @@ extern "C" {
// https://github.com/openxla/xla/blob/main/xla/pjrt/layout_mode.h for more
// details.
#define PJRT_API_LAYOUTS_EXTENSION_VERSION 2
#define PJRT_API_LAYOUTS_EXTENSION_VERSION 3
// -------------------------------- Data types ---------------------------------
@ -124,6 +125,23 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_PJRT_Topology_GetDefaultLayout_Args,
typedef PJRT_Error* PJRT_Layouts_PJRT_Topology_GetDefaultLayout(
PJRT_Layouts_PJRT_Topology_GetDefaultLayout_Args* args);
// Argument struct for PJRT_Layouts_PJRT_Executable_GetOutputLayouts, which
// returns output layouts for an executable.
//
// Fields marked `// out` are written by the callee; the remaining fields are
// supplied by the caller.
struct PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args {
// Size of this struct as seen by the caller; the implementation compares it
// against PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE.
size_t struct_size;
// Extension chain for this call; callers that pass no extensions set this to
// nullptr.
PJRT_Extension_Base* extension_start;
// Executable whose output layouts are requested.
PJRT_Executable* executable;
// Number of entries in `layouts`.
size_t num_outputs; // out
// Layout data is owned by and has the lifetime of `executable`.
// Has length `num_outputs`.
PJRT_Layouts_MemoryLayout** layouts; // out
};
// Names `layouts` as the last field of the struct (presumably what
// ..._STRUCT_SIZE is derived from -- confirm against the macro definition).
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args,
layouts);
// Returns a list of layouts for executable outputs. Each output has a layout.
// On success, `args->num_outputs` and `args->layouts` are filled in; the
// returned layout pointers are owned by `args->executable` (see the Args
// struct) and must not be destroyed by the caller.
typedef PJRT_Error* PJRT_Layouts_PJRT_Executable_GetOutputLayouts(
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args* args);
// --------------------------- Extension entrypoint ----------------------------
typedef struct PJRT_Layouts_Extension {
@ -136,9 +154,11 @@ typedef struct PJRT_Layouts_Extension {
PJRT_Layouts_PJRT_Buffer_MemoryLayout* PJRT_Layouts_PJRT_Buffer_MemoryLayout;
PJRT_Layouts_PJRT_Topology_GetDefaultLayout*
PJRT_Layouts_PJRT_Topology_GetDefaultLayout;
PJRT_Layouts_PJRT_Executable_GetOutputLayouts*
PJRT_Layouts_PJRT_Executable_GetOutputLayouts;
} PJRT_Layouts_Extension;
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_Extension,
PJRT_Layouts_PJRT_Topology_GetDefaultLayout);
PJRT_Layouts_PJRT_Executable_GetOutputLayouts);
#ifdef __cplusplus
}

View File

@ -219,6 +219,33 @@ static absl::Status EnsureExecutableOutputDimensionsPopulated(
return absl::OkStatus();
}
// Queries the wrapped C++ executable for its output layouts and caches them on
// `executable`, both as C API layout objects (`out_layouts`) and as an array of
// pointers to those objects (`out_layouts_pointers`).
//
// Returns the error from GetOutputLayouts() without touching the caches if the
// query fails.
static absl::Status PopulateExecutableOutputLayouts(
    PJRT_Executable* executable) {
  TF_ASSIGN_OR_RETURN(
      std::vector<std::shared_ptr<const xla::PjRtLayout>> fetched_layouts,
      executable->get()->GetOutputLayouts());

  std::vector<PJRT_Layouts_MemoryLayout>& storage = executable->out_layouts;
  std::vector<PJRT_Layouts_MemoryLayout*>& pointers =
      executable->out_layouts_pointers;
  const size_t fetched_count = fetched_layouts.size();
  storage.reserve(fetched_count);
  pointers.reserve(fetched_count);

  // First pass: take ownership of every layout.
  for (size_t i = 0; i < fetched_count; ++i) {
    storage.push_back(
        PJRT_Layouts_MemoryLayout{std::move(fetched_layouts[i])});
  }
  // Second pass: record addresses only once `storage` has stopped growing, so
  // the stored pointers cannot be invalidated by a reallocation.
  for (size_t i = 0; i < storage.size(); ++i) {
    pointers.push_back(&storage[i]);
  }
  return absl::OkStatus();
}
// Fills the output-layout caches of `executable` the first time it is called.
// The check and the population both happen while holding `executable->mutex`.
// `out_layouts_ran` is only set after a successful population, so a failed
// attempt is retried on the next call.
static absl::Status EnsureExecutableOutputLayoutsPopulated(
    PJRT_Executable* executable) {
  absl::MutexLock lock(executable->mutex);
  if (executable->out_layouts_ran) {
    // Already cached by an earlier call.
    return absl::OkStatus();
  }
  TF_RETURN_IF_ERROR(PopulateExecutableOutputLayouts(executable));
  executable->out_layouts_ran = true;
  return absl::OkStatus();
}
static absl::Status PopulateExecutableOutputMemoryKinds(
PJRT_Executable* executable) {
TF_ASSIGN_OR_RETURN(
@ -2692,6 +2719,19 @@ PJRT_Error* PJRT_Layouts_PJRT_Topology_GetDefaultLayout(
return nullptr;
}
// C API entry point: reports the output layouts of `args->executable`.
//
// Validates `args->struct_size`, makes sure the executable's layout cache is
// populated, and then exposes that cache through `args->layouts` /
// `args->num_outputs`. The returned array points into storage held by the
// executable. Returns nullptr on success, or a PJRT_Error otherwise.
PJRT_Error* PJRT_Layouts_PJRT_Executable_GetOutputLayouts(
    PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args* args) {
  PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
      "PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args",
      PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE,
      args->struct_size));

  PJRT_Executable* const target = args->executable;
  PJRT_RETURN_IF_ERROR(EnsureExecutableOutputLayoutsPopulated(target));

  std::vector<PJRT_Layouts_MemoryLayout*>& cached_pointers =
      target->out_layouts_pointers;
  args->layouts = cached_pointers.data();
  args->num_outputs = cached_pointers.size();
  return nullptr;
}
static std::vector<PJRT_NamedValue> PopulatePjrtAttributes(
const absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>&
attributes) {
@ -3106,6 +3146,8 @@ PJRT_Layouts_Extension CreateLayoutsExtension(PJRT_Extension_Base* next) {
pjrt::PJRT_Layouts_PJRT_Buffer_MemoryLayout,
/*PJRT_Layouts_PJRT_Topology_GetDefaultLayout=*/
&PJRT_Layouts_PJRT_Topology_GetDefaultLayout,
/*PJRT_Layouts_PJRT_Executable_GetOutputLayouts=*/
&PJRT_Layouts_PJRT_Executable_GetOutputLayouts,
};
}

View File

@ -158,6 +158,10 @@ struct PJRT_Executable {
std::vector<int64_t> out_dimensions;
std::vector<size_t> out_dimension_sizes;
bool out_layouts_ran ABSL_GUARDED_BY(mutex) = false;
std::vector<PJRT_Layouts_MemoryLayout> out_layouts;
std::vector<PJRT_Layouts_MemoryLayout*> out_layouts_pointers;
explicit PJRT_Executable(std::shared_ptr<xla::PjRtExecutable> executable);
explicit PJRT_Executable(xla::PjRtExecutable* executable);

View File

@ -1725,6 +1725,57 @@ PjRtCApiExecutable::GetOutputDimensions() const {
return std::vector<std::vector<DimensionVector>>{std::move(out)};
}
// Returns the layout of each executable output.
//
// When the plugin's PJRT Layouts extension provides both
// PJRT_Layouts_PJRT_Executable_GetOutputLayouts and
// PJRT_Layouts_MemoryLayout_Serialize, the layouts are obtained through the C
// API and converted by serializing each C API layout and deserializing it
// into a PjRtLayout. Otherwise this defers to the base-class
// PjRtExecutable::GetOutputLayouts().
//
// Errors reported by the plugin or by deserialization are returned as a
// non-OK status rather than terminating the process, since callers of this
// method already handle absl::StatusOr.
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
PjRtCApiExecutable::GetOutputLayouts() const {
  const PJRT_Api* c_api = pjrt_c_api();
  PJRT_Layouts_Extension* extension =
      pjrt::FindExtension<PJRT_Layouts_Extension>(
          c_api, PJRT_Extension_Type::PJRT_Extension_Type_Layouts);
  // NOTE(review): a plugin built against an older Layouts extension version
  // may expose a shorter extension struct that lacks
  // PJRT_Layouts_PJRT_Executable_GetOutputLayouts; the null check below
  // presumably relies on the plugin having been upgraded -- confirm whether a
  // struct-size check is needed before reading this field.
  if (extension == nullptr ||
      extension->PJRT_Layouts_MemoryLayout_Serialize == nullptr ||
      extension->PJRT_Layouts_PJRT_Executable_GetOutputLayouts == nullptr) {
    // If we can't find PJRT_Layouts_PJRT_Executable_GetOutputLayouts support,
    // fall back to the default implementation.
    return this->PjRtExecutable::GetOutputLayouts();
  }

  PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args args;
  args.struct_size =
      PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE;
  args.extension_start = nullptr;
  args.executable = c_executable();
  RETURN_STATUS_IF_PJRT_ERROR(
      extension->PJRT_Layouts_PJRT_Executable_GetOutputLayouts(&args), c_api);

  std::vector<std::shared_ptr<const PjRtLayout>> layouts;
  layouts.reserve(args.num_outputs);
  // `num_outputs` is a size_t, so index with size_t to avoid a signed/unsigned
  // comparison.
  for (size_t i = 0; i < args.num_outputs; ++i) {
    // TODO(b/343274093): returns a PjRtLayout that wraps a C API layout
    // directly instead of de/serializing into an xla::Layout.
    PJRT_Layouts_MemoryLayout_Serialize_Args serialize_args;
    serialize_args.struct_size =
        PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE;
    serialize_args.extension_start = nullptr;
    serialize_args.layout = args.layouts[i];
    // Propagate a serialization failure to the caller. The cleanup below is
    // only registered after a successful call, so nothing is released here on
    // the error path.
    RETURN_STATUS_IF_PJRT_ERROR(
        extension->PJRT_Layouts_MemoryLayout_Serialize(&serialize_args), c_api);
    // Clean up `PJRT_Layouts_SerializedLayout`.
    absl::Cleanup cleanup = [&serialize_args] {
      serialize_args.serialized_layout_deleter(
          serialize_args.serialized_layout);
    };
    // Copy the bytes out before `cleanup` releases the serialized layout at
    // the end of this iteration.
    std::string serialized_layout(serialize_args.serialized_bytes,
                                  serialize_args.serialized_bytes_size);
    absl::StatusOr<std::shared_ptr<const PjRtLayout>> pjrt_layout =
        PjRtLayout::Deserialize(serialized_layout);
    if (!pjrt_layout.ok()) {
      return pjrt_layout.status();
    }
    layouts.push_back(*std::move(pjrt_layout));
  }
  return layouts;
}
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
PjRtCApiExecutable::GetOutputMemoryKinds() const {
PJRT_Executable_OutputMemoryKinds_Args args;

View File

@ -590,6 +590,9 @@ class PjRtCApiExecutable : public PjRtExecutable {
absl::StatusOr<std::vector<std::vector<DimensionVector>>>
GetOutputDimensions() const override;
// Returns one layout per executable output. The implementation uses the PJRT
// Layouts extension when the plugin provides it and otherwise defers to
// PjRtExecutable::GetOutputLayouts().
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
GetOutputLayouts() const override;
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
GetOutputMemoryKinds() const override;
@ -671,6 +674,11 @@ class PjRtCApiLoadedExecutable : public PjRtLoadedExecutable {
return executable_->GetOutputDimensions();
}
// Returns the output layouts by forwarding to the wrapped `executable_`.
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
GetOutputLayouts() const override {
return executable_->GetOutputLayouts();
}
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
GetOutputMemoryKinds() const override {
return executable_->GetOutputMemoryKinds();