mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[PJRT C] Implement Executable::GetOutputLayouts() in the PJRT Layouts extension
This change implements a native support for `xla::Executable::GetOutputLayouts()` in PJRT C API, when PJRT Layouts extension is available. This support does not fetch the optimized HLO, and thus this method becomes more reliable and fast. This change strongly recommends the plugin that implemented the Layouts extension v2 to upgrade to v3 to avoid an incompatibility. PiperOrigin-RevId: 821834116
This commit is contained in:
parent
a40f3bdebd
commit
67e5eafb24
|
|
@ -26,8 +26,9 @@ extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// This extension provides capabilities around custom on-device memory layouts
|
// This extension provides capabilities around custom on-device memory layouts
|
||||||
// for PJRT_Buffers. The extension is both optional and experimental, meaning
|
// for PJRT_Buffers and PJRT_Executables. The extension is both optional and
|
||||||
// ABI-breaking and other incompatible changes may be introduced at any time.
|
// experimental, meaning ABI-breaking and other incompatible changes may be
|
||||||
|
// introduced at any time.
|
||||||
//
|
//
|
||||||
// If this extension is provided, JAX and possibly other frameworks will assume
|
// If this extension is provided, JAX and possibly other frameworks will assume
|
||||||
// that the compiler MLIR input can contain "mhlo.layout_mode" attributes on
|
// that the compiler MLIR input can contain "mhlo.layout_mode" attributes on
|
||||||
|
|
@ -36,7 +37,7 @@ extern "C" {
|
||||||
// https://github.com/openxla/xla/blob/main/xla/pjrt/layout_mode.h for more
|
// https://github.com/openxla/xla/blob/main/xla/pjrt/layout_mode.h for more
|
||||||
// details.
|
// details.
|
||||||
|
|
||||||
#define PJRT_API_LAYOUTS_EXTENSION_VERSION 2
|
#define PJRT_API_LAYOUTS_EXTENSION_VERSION 3
|
||||||
|
|
||||||
// -------------------------------- Data types ---------------------------------
|
// -------------------------------- Data types ---------------------------------
|
||||||
|
|
||||||
|
|
@ -124,6 +125,23 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_PJRT_Topology_GetDefaultLayout_Args,
|
||||||
typedef PJRT_Error* PJRT_Layouts_PJRT_Topology_GetDefaultLayout(
|
typedef PJRT_Error* PJRT_Layouts_PJRT_Topology_GetDefaultLayout(
|
||||||
PJRT_Layouts_PJRT_Topology_GetDefaultLayout_Args* args);
|
PJRT_Layouts_PJRT_Topology_GetDefaultLayout_Args* args);
|
||||||
|
|
||||||
|
// Returns output layouts for an executable.
|
||||||
|
struct PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args {
|
||||||
|
size_t struct_size;
|
||||||
|
PJRT_Extension_Base* extension_start;
|
||||||
|
PJRT_Executable* executable;
|
||||||
|
size_t num_outputs; // out
|
||||||
|
// Layout data is owned by and has the lifetime of `executable`.
|
||||||
|
// Has length `num_outputs`.
|
||||||
|
PJRT_Layouts_MemoryLayout** layouts; // out
|
||||||
|
};
|
||||||
|
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args,
|
||||||
|
layouts);
|
||||||
|
|
||||||
|
// Returns a list of layouts for executable outputs. Each output has a layout.
|
||||||
|
typedef PJRT_Error* PJRT_Layouts_PJRT_Executable_GetOutputLayouts(
|
||||||
|
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args* args);
|
||||||
|
|
||||||
// --------------------------- Extension entrypoint ----------------------------
|
// --------------------------- Extension entrypoint ----------------------------
|
||||||
|
|
||||||
typedef struct PJRT_Layouts_Extension {
|
typedef struct PJRT_Layouts_Extension {
|
||||||
|
|
@ -136,9 +154,11 @@ typedef struct PJRT_Layouts_Extension {
|
||||||
PJRT_Layouts_PJRT_Buffer_MemoryLayout* PJRT_Layouts_PJRT_Buffer_MemoryLayout;
|
PJRT_Layouts_PJRT_Buffer_MemoryLayout* PJRT_Layouts_PJRT_Buffer_MemoryLayout;
|
||||||
PJRT_Layouts_PJRT_Topology_GetDefaultLayout*
|
PJRT_Layouts_PJRT_Topology_GetDefaultLayout*
|
||||||
PJRT_Layouts_PJRT_Topology_GetDefaultLayout;
|
PJRT_Layouts_PJRT_Topology_GetDefaultLayout;
|
||||||
|
PJRT_Layouts_PJRT_Executable_GetOutputLayouts*
|
||||||
|
PJRT_Layouts_PJRT_Executable_GetOutputLayouts;
|
||||||
} PJRT_Layouts_Extension;
|
} PJRT_Layouts_Extension;
|
||||||
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_Extension,
|
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_Extension,
|
||||||
PJRT_Layouts_PJRT_Topology_GetDefaultLayout);
|
PJRT_Layouts_PJRT_Executable_GetOutputLayouts);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -219,6 +219,33 @@ static absl::Status EnsureExecutableOutputDimensionsPopulated(
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static absl::Status PopulateExecutableOutputLayouts(
|
||||||
|
PJRT_Executable* executable) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
std::vector<std::shared_ptr<const xla::PjRtLayout>> cpp_out_layouts,
|
||||||
|
executable->get()->GetOutputLayouts());
|
||||||
|
executable->out_layouts.reserve(cpp_out_layouts.size());
|
||||||
|
executable->out_layouts_pointers.reserve(cpp_out_layouts.size());
|
||||||
|
for (std::shared_ptr<const xla::PjRtLayout>& layout : cpp_out_layouts) {
|
||||||
|
executable->out_layouts.push_back(
|
||||||
|
PJRT_Layouts_MemoryLayout{std::move(layout)});
|
||||||
|
}
|
||||||
|
for (PJRT_Layouts_MemoryLayout& layout : executable->out_layouts) {
|
||||||
|
executable->out_layouts_pointers.push_back(&layout);
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
static absl::Status EnsureExecutableOutputLayoutsPopulated(
|
||||||
|
PJRT_Executable* executable) {
|
||||||
|
absl::MutexLock lock(executable->mutex);
|
||||||
|
if (!executable->out_layouts_ran) {
|
||||||
|
TF_RETURN_IF_ERROR(PopulateExecutableOutputLayouts(executable));
|
||||||
|
executable->out_layouts_ran = true;
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
static absl::Status PopulateExecutableOutputMemoryKinds(
|
static absl::Status PopulateExecutableOutputMemoryKinds(
|
||||||
PJRT_Executable* executable) {
|
PJRT_Executable* executable) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
|
@ -2692,6 +2719,19 @@ PJRT_Error* PJRT_Layouts_PJRT_Topology_GetDefaultLayout(
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PJRT_Error* PJRT_Layouts_PJRT_Executable_GetOutputLayouts(
|
||||||
|
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args* args) {
|
||||||
|
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
|
||||||
|
"PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args",
|
||||||
|
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE,
|
||||||
|
args->struct_size));
|
||||||
|
PJRT_RETURN_IF_ERROR(
|
||||||
|
EnsureExecutableOutputLayoutsPopulated(args->executable));
|
||||||
|
args->num_outputs = args->executable->out_layouts_pointers.size();
|
||||||
|
args->layouts = args->executable->out_layouts_pointers.data();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
static std::vector<PJRT_NamedValue> PopulatePjrtAttributes(
|
static std::vector<PJRT_NamedValue> PopulatePjrtAttributes(
|
||||||
const absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>&
|
const absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>&
|
||||||
attributes) {
|
attributes) {
|
||||||
|
|
@ -3106,6 +3146,8 @@ PJRT_Layouts_Extension CreateLayoutsExtension(PJRT_Extension_Base* next) {
|
||||||
pjrt::PJRT_Layouts_PJRT_Buffer_MemoryLayout,
|
pjrt::PJRT_Layouts_PJRT_Buffer_MemoryLayout,
|
||||||
/*PJRT_Layouts_PJRT_Topology_GetDefaultLayout=*/
|
/*PJRT_Layouts_PJRT_Topology_GetDefaultLayout=*/
|
||||||
&PJRT_Layouts_PJRT_Topology_GetDefaultLayout,
|
&PJRT_Layouts_PJRT_Topology_GetDefaultLayout,
|
||||||
|
/*PJRT_Layouts_PJRT_Executable_GetOutputLayouts=*/
|
||||||
|
&PJRT_Layouts_PJRT_Executable_GetOutputLayouts,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -158,6 +158,10 @@ struct PJRT_Executable {
|
||||||
std::vector<int64_t> out_dimensions;
|
std::vector<int64_t> out_dimensions;
|
||||||
std::vector<size_t> out_dimension_sizes;
|
std::vector<size_t> out_dimension_sizes;
|
||||||
|
|
||||||
|
bool out_layouts_ran ABSL_GUARDED_BY(mutex) = false;
|
||||||
|
std::vector<PJRT_Layouts_MemoryLayout> out_layouts;
|
||||||
|
std::vector<PJRT_Layouts_MemoryLayout*> out_layouts_pointers;
|
||||||
|
|
||||||
explicit PJRT_Executable(std::shared_ptr<xla::PjRtExecutable> executable);
|
explicit PJRT_Executable(std::shared_ptr<xla::PjRtExecutable> executable);
|
||||||
explicit PJRT_Executable(xla::PjRtExecutable* executable);
|
explicit PJRT_Executable(xla::PjRtExecutable* executable);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1725,6 +1725,57 @@ PjRtCApiExecutable::GetOutputDimensions() const {
|
||||||
return std::vector<std::vector<DimensionVector>>{std::move(out)};
|
return std::vector<std::vector<DimensionVector>>{std::move(out)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
|
||||||
|
PjRtCApiExecutable::GetOutputLayouts() const {
|
||||||
|
const PJRT_Api* c_api = pjrt_c_api();
|
||||||
|
PJRT_Layouts_Extension* extension =
|
||||||
|
pjrt::FindExtension<PJRT_Layouts_Extension>(
|
||||||
|
c_api, PJRT_Extension_Type::PJRT_Extension_Type_Layouts);
|
||||||
|
if (extension == nullptr ||
|
||||||
|
extension->PJRT_Layouts_MemoryLayout_Serialize == nullptr ||
|
||||||
|
extension->PJRT_Layouts_PJRT_Executable_GetOutputLayouts == nullptr) {
|
||||||
|
// If we can't find PJRT_Layouts_PJRT_Executable_GetOutputLayouts support,
|
||||||
|
// fall back to the default implementation.
|
||||||
|
return this->PjRtExecutable::GetOutputLayouts();
|
||||||
|
}
|
||||||
|
|
||||||
|
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args args;
|
||||||
|
args.struct_size =
|
||||||
|
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE;
|
||||||
|
args.extension_start = nullptr;
|
||||||
|
args.executable = c_executable();
|
||||||
|
RETURN_STATUS_IF_PJRT_ERROR(
|
||||||
|
extension->PJRT_Layouts_PJRT_Executable_GetOutputLayouts(&args), c_api);
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<const PjRtLayout>> layouts;
|
||||||
|
layouts.reserve(args.num_outputs);
|
||||||
|
for (int i = 0; i < args.num_outputs; ++i) {
|
||||||
|
// TODO(b/343274093): returns a PjRtLayout that wraps a C API layout
|
||||||
|
// directly instead of de/serializing into an xla::Layout.
|
||||||
|
PJRT_Layouts_MemoryLayout_Serialize_Args serialize_args;
|
||||||
|
serialize_args.struct_size =
|
||||||
|
PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE;
|
||||||
|
serialize_args.extension_start = nullptr;
|
||||||
|
serialize_args.layout = args.layouts[i];
|
||||||
|
pjrt::LogFatalIfPjrtError(
|
||||||
|
extension->PJRT_Layouts_MemoryLayout_Serialize(&serialize_args), c_api);
|
||||||
|
|
||||||
|
// Clean up `PJRT_Layouts_SerializedLayout`.
|
||||||
|
absl::Cleanup cleanup = [&serialize_args] {
|
||||||
|
serialize_args.serialized_layout_deleter(
|
||||||
|
serialize_args.serialized_layout);
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string serialized_layout(serialize_args.serialized_bytes,
|
||||||
|
serialize_args.serialized_bytes_size);
|
||||||
|
absl::StatusOr<std::shared_ptr<const PjRtLayout>> pjrt_layout =
|
||||||
|
PjRtLayout::Deserialize(serialized_layout);
|
||||||
|
TF_CHECK_OK(pjrt_layout.status());
|
||||||
|
layouts.push_back(*std::move(pjrt_layout));
|
||||||
|
}
|
||||||
|
return layouts;
|
||||||
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
|
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
|
||||||
PjRtCApiExecutable::GetOutputMemoryKinds() const {
|
PjRtCApiExecutable::GetOutputMemoryKinds() const {
|
||||||
PJRT_Executable_OutputMemoryKinds_Args args;
|
PJRT_Executable_OutputMemoryKinds_Args args;
|
||||||
|
|
|
||||||
|
|
@ -590,6 +590,9 @@ class PjRtCApiExecutable : public PjRtExecutable {
|
||||||
absl::StatusOr<std::vector<std::vector<DimensionVector>>>
|
absl::StatusOr<std::vector<std::vector<DimensionVector>>>
|
||||||
GetOutputDimensions() const override;
|
GetOutputDimensions() const override;
|
||||||
|
|
||||||
|
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
|
||||||
|
GetOutputLayouts() const override;
|
||||||
|
|
||||||
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
|
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
|
||||||
GetOutputMemoryKinds() const override;
|
GetOutputMemoryKinds() const override;
|
||||||
|
|
||||||
|
|
@ -671,6 +674,11 @@ class PjRtCApiLoadedExecutable : public PjRtLoadedExecutable {
|
||||||
return executable_->GetOutputDimensions();
|
return executable_->GetOutputDimensions();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
|
||||||
|
GetOutputLayouts() const override {
|
||||||
|
return executable_->GetOutputLayouts();
|
||||||
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
|
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
|
||||||
GetOutputMemoryKinds() const override {
|
GetOutputMemoryKinds() const override {
|
||||||
return executable_->GetOutputMemoryKinds();
|
return executable_->GetOutputMemoryKinds();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user