PiperOrigin-RevId: 822724128
This commit is contained in:
Hyeontaek Lim 2025-10-22 13:43:58 -07:00 committed by TensorFlower Gardener
parent aeda5dabd4
commit 70111bb38f
7 changed files with 140 additions and 5 deletions

View File

@ -1,5 +1,9 @@
# PJRT C API changelog
## 0.81
* Added `PJRT_Layouts_PJRT_Executable_GetOutputLayouts`.
## 0.80
* Added `PJRT_Extension_Type::PJRT_Extension_Type_HostAllocator`.

View File

@ -103,7 +103,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 80
#define PJRT_API_MINOR 81
// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in

View File

@ -26,8 +26,9 @@ extern "C" {
#endif
// This extension provides capabilities around custom on-device memory layouts
// for PJRT_Buffers. The extension is both optional and experimental, meaning
// ABI-breaking and other incompatible changes may be introduced at any time.
// for PJRT_Buffers and PJRT_Executables. The extension is both optional and
// experimental, meaning ABI-breaking and other incompatible changes may be
// introduced at any time.
//
// If this extension is provided, JAX and possibly other frameworks will assume
// that the compiler MLIR input can contain "mhlo.layout_mode" attributes on
@ -36,7 +37,7 @@ extern "C" {
// https://github.com/openxla/xla/blob/main/xla/pjrt/layout_mode.h for more
// details.
#define PJRT_API_LAYOUTS_EXTENSION_VERSION 2
#define PJRT_API_LAYOUTS_EXTENSION_VERSION 3
// -------------------------------- Data types ---------------------------------
@ -124,6 +125,23 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_PJRT_Topology_GetDefaultLayout_Args,
typedef PJRT_Error* PJRT_Layouts_PJRT_Topology_GetDefaultLayout(
PJRT_Layouts_PJRT_Topology_GetDefaultLayout_Args* args);
// Argument struct for `PJRT_Layouts_PJRT_Executable_GetOutputLayouts`, which
// returns output layouts for an executable.
struct PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args {
// Size of this struct as known to the caller; callers set it to
// `PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE`.
size_t struct_size;
// Optional extension pointer; may be nullptr.
PJRT_Extension_Base* extension_start;
// Executable whose output layouts are requested (input).
PJRT_Executable* executable;
// Number of entries written to `layouts`.
size_t num_outputs; // out
// Layout data is owned by and has the lifetime of `executable`.
// Has length `num_outputs`. Callers must not free the array or its elements.
PJRT_Layouts_MemoryLayout** layouts; // out
};
// `layouts` is named here as the last field of the args struct.
// NOTE(review): presumably this macro derives the
// `PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE` constant
// from that field, so update the field name if members are appended - confirm
// against the macro definition.
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args,
layouts);
// Function type of the extension entry point that returns a list of layouts
// for executable outputs. Each output has a layout. Results are delivered
// through the `num_outputs` and `layouts` fields of `args`; the return value
// is nullptr on success.
typedef PJRT_Error* PJRT_Layouts_PJRT_Executable_GetOutputLayouts(
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args* args);
// --------------------------- Extension entrypoint ----------------------------
typedef struct PJRT_Layouts_Extension {
@ -136,9 +154,11 @@ typedef struct PJRT_Layouts_Extension {
PJRT_Layouts_PJRT_Buffer_MemoryLayout* PJRT_Layouts_PJRT_Buffer_MemoryLayout;
PJRT_Layouts_PJRT_Topology_GetDefaultLayout*
PJRT_Layouts_PJRT_Topology_GetDefaultLayout;
PJRT_Layouts_PJRT_Executable_GetOutputLayouts*
PJRT_Layouts_PJRT_Executable_GetOutputLayouts;
} PJRT_Layouts_Extension;
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_Extension,
PJRT_Layouts_PJRT_Topology_GetDefaultLayout);
PJRT_Layouts_PJRT_Executable_GetOutputLayouts);
#ifdef __cplusplus
}

View File

@ -219,6 +219,33 @@ static absl::Status EnsureExecutableOutputDimensionsPopulated(
return absl::OkStatus();
}
// Queries the wrapped C++ executable for its output layouts and stores them on
// `executable`: `out_layouts` holds the C API layout wrappers and
// `out_layouts_pointers` holds one pointer per wrapper, in the same order.
// Returns the error from `GetOutputLayouts()` unchanged if the query fails, in
// which case `executable` is left untouched.
static absl::Status PopulateExecutableOutputLayouts(
    PJRT_Executable* executable) {
  TF_ASSIGN_OR_RETURN(
      std::vector<std::shared_ptr<const xla::PjRtLayout>> fetched_layouts,
      executable->get()->GetOutputLayouts());

  const size_t layout_count = fetched_layouts.size();
  auto& owned_layouts = executable->out_layouts;
  auto& layout_handles = executable->out_layouts_pointers;
  owned_layouts.reserve(layout_count);
  layout_handles.reserve(layout_count);

  for (size_t i = 0; i < layout_count; ++i) {
    owned_layouts.push_back(
        PJRT_Layouts_MemoryLayout{std::move(fetched_layouts[i])});
  }
  // Addresses are collected in a second pass, once `owned_layouts` has stopped
  // growing, so every recorded pointer refers to the element's final location.
  for (size_t i = 0; i < owned_layouts.size(); ++i) {
    layout_handles.push_back(&owned_layouts[i]);
  }
  return absl::OkStatus();
}
// Populates the cached output layouts of `executable` the first time it is
// called; later calls are no-ops. The check and the population both run while
// holding `executable->mutex`. If population fails, the error is returned and
// the "already ran" flag stays false, so a later call will try again.
static absl::Status EnsureExecutableOutputLayoutsPopulated(
    PJRT_Executable* executable) {
  absl::MutexLock lock(executable->mutex);
  if (executable->out_layouts_ran) {
    return absl::OkStatus();
  }
  TF_RETURN_IF_ERROR(PopulateExecutableOutputLayouts(executable));
  executable->out_layouts_ran = true;
  return absl::OkStatus();
}
static absl::Status PopulateExecutableOutputMemoryKinds(
PJRT_Executable* executable) {
TF_ASSIGN_OR_RETURN(
@ -2692,6 +2719,19 @@ PJRT_Error* PJRT_Layouts_PJRT_Topology_GetDefaultLayout(
return nullptr;
}
// C API entry point: reports the output layouts of `args->executable`.
//
// Validates `args->struct_size`, makes sure the executable's layout cache has
// been filled, then points `args->layouts` at the cached pointer array and
// sets `args->num_outputs` to its length. The returned array is storage owned
// by the executable. Returns nullptr on success.
PJRT_Error* PJRT_Layouts_PJRT_Executable_GetOutputLayouts(
    PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args* args) {
  PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
      "PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args",
      PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE,
      args->struct_size));

  PJRT_Executable* const target = args->executable;
  PJRT_RETURN_IF_ERROR(EnsureExecutableOutputLayoutsPopulated(target));

  std::vector<PJRT_Layouts_MemoryLayout*>& cached_handles =
      target->out_layouts_pointers;
  args->layouts = cached_handles.data();
  args->num_outputs = cached_handles.size();
  return nullptr;
}
static std::vector<PJRT_NamedValue> PopulatePjrtAttributes(
const absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>&
attributes) {
@ -3106,6 +3146,8 @@ PJRT_Layouts_Extension CreateLayoutsExtension(PJRT_Extension_Base* next) {
pjrt::PJRT_Layouts_PJRT_Buffer_MemoryLayout,
/*PJRT_Layouts_PJRT_Topology_GetDefaultLayout=*/
&PJRT_Layouts_PJRT_Topology_GetDefaultLayout,
/*PJRT_Layouts_PJRT_Executable_GetOutputLayouts=*/
&PJRT_Layouts_PJRT_Executable_GetOutputLayouts,
};
}

View File

@ -158,6 +158,10 @@ struct PJRT_Executable {
std::vector<int64_t> out_dimensions;
std::vector<size_t> out_dimension_sizes;
bool out_layouts_ran ABSL_GUARDED_BY(mutex) = false;
std::vector<PJRT_Layouts_MemoryLayout> out_layouts;
std::vector<PJRT_Layouts_MemoryLayout*> out_layouts_pointers;
explicit PJRT_Executable(std::shared_ptr<xla::PjRtExecutable> executable);
explicit PJRT_Executable(xla::PjRtExecutable* executable);

View File

@ -1725,6 +1725,63 @@ PjRtCApiExecutable::GetOutputDimensions() const {
return std::vector<std::vector<DimensionVector>>{std::move(out)};
}
// Returns the layout of every executable output, obtained from the plugin via
// the PJRT Layouts extension.
//
// Falls back to the base-class `PjRtExecutable::GetOutputLayouts()` when the
// plugin reports a C API version older than 0.81, or when the Layouts
// extension (or either entry point needed here) is unavailable.
//
// Errors from the plugin - including failures while serializing an individual
// layout - and failures to deserialize a layout are returned as a non-OK
// status rather than terminating the process, since callers already have to
// handle the `StatusOr` result.
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
PjRtCApiExecutable::GetOutputLayouts() const {
  const PJRT_Api* c_api = pjrt_c_api();
  if (c_api->pjrt_api_version.major_version == 0 &&
      c_api->pjrt_api_version.minor_version < 81) {
    // If the PJRT C API version is too old, fall back to the default
    // implementation.
    return this->PjRtExecutable::GetOutputLayouts();
  }
  PJRT_Layouts_Extension* extension =
      pjrt::FindExtension<PJRT_Layouts_Extension>(
          c_api, PJRT_Extension_Type::PJRT_Extension_Type_Layouts);
  if (extension == nullptr ||
      extension->PJRT_Layouts_MemoryLayout_Serialize == nullptr ||
      extension->PJRT_Layouts_PJRT_Executable_GetOutputLayouts == nullptr) {
    // If we can't find PJRT_Layouts_PJRT_Executable_GetOutputLayouts support,
    // fall back to the default implementation.
    return this->PjRtExecutable::GetOutputLayouts();
  }

  PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args args;
  args.struct_size =
      PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE;
  args.extension_start = nullptr;
  args.executable = c_executable();
  // Give the out-params defined values so that a plugin that returns success
  // without writing them yields an empty result instead of indeterminate reads.
  args.num_outputs = 0;
  args.layouts = nullptr;
  RETURN_STATUS_IF_PJRT_ERROR(
      extension->PJRT_Layouts_PJRT_Executable_GetOutputLayouts(&args), c_api);

  std::vector<std::shared_ptr<const PjRtLayout>> layouts;
  layouts.reserve(args.num_outputs);
  // `num_outputs` is a size_t; use an unsigned index to avoid a signed/unsigned
  // comparison.
  for (size_t i = 0; i < args.num_outputs; ++i) {
    // TODO(b/343274093): returns a PjRtLayout that wraps a C API layout
    // directly instead of de/serializing into an xla::Layout.
    PJRT_Layouts_MemoryLayout_Serialize_Args serialize_args;
    serialize_args.struct_size =
        PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE;
    serialize_args.extension_start = nullptr;
    serialize_args.layout = args.layouts[i];
    // Propagate serialization failures. The cleanup below is registered only
    // after this check, so the deleter is never invoked on unset out-params.
    RETURN_STATUS_IF_PJRT_ERROR(
        extension->PJRT_Layouts_MemoryLayout_Serialize(&serialize_args), c_api);

    // Clean up `PJRT_Layouts_SerializedLayout`, including on the early return
    // below.
    absl::Cleanup cleanup = [&serialize_args] {
      serialize_args.serialized_layout_deleter(
          serialize_args.serialized_layout);
    };

    // Copy the bytes out before the serialized layout is released.
    std::string serialized_layout(serialize_args.serialized_bytes,
                                  serialize_args.serialized_bytes_size);
    absl::StatusOr<std::shared_ptr<const PjRtLayout>> pjrt_layout =
        PjRtLayout::Deserialize(serialized_layout);
    if (!pjrt_layout.ok()) {
      return pjrt_layout.status();
    }
    layouts.push_back(*std::move(pjrt_layout));
  }
  return layouts;
}
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
PjRtCApiExecutable::GetOutputMemoryKinds() const {
PJRT_Executable_OutputMemoryKinds_Args args;

View File

@ -590,6 +590,9 @@ class PjRtCApiExecutable : public PjRtExecutable {
absl::StatusOr<std::vector<std::vector<DimensionVector>>>
GetOutputDimensions() const override;
// Returns one layout per executable output, queried through the plugin's
// Layouts extension. Falls back to the base-class implementation when the
// plugin's C API version is older than 0.81 or the required extension entry
// points are not available.
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
GetOutputLayouts() const override;
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
GetOutputMemoryKinds() const override;
@ -671,6 +674,11 @@ class PjRtCApiLoadedExecutable : public PjRtLoadedExecutable {
return executable_->GetOutputDimensions();
}
// Returns the output layouts by forwarding to the wrapped `executable_`; the
// result (including any error status) is passed through unchanged.
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
GetOutputLayouts() const override {
return executable_->GetOutputLayouts();
}
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
GetOutputMemoryKinds() const override {
return executable_->GetOutputMemoryKinds();