mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
parent
aeda5dabd4
commit
70111bb38f
4
third_party/xla/xla/pjrt/c/CHANGELOG.md
vendored
4
third_party/xla/xla/pjrt/c/CHANGELOG.md
vendored
|
|
@ -1,5 +1,9 @@
|
|||
# PJRT C API changelog
|
||||
|
||||
## 0.81
|
||||
|
||||
* Added `PJRT_Layouts_PJRT_Executable_GetOutputLayouts`.
|
||||
|
||||
## 0.80
|
||||
|
||||
* Added `PJRT_Extension_Type::PJRT_Extension_Type_HostAllocator`.
|
||||
|
|
|
|||
2
third_party/xla/xla/pjrt/c/pjrt_c_api.h
vendored
2
third_party/xla/xla/pjrt/c/pjrt_c_api.h
vendored
|
|
@ -103,7 +103,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next);
|
|||
// Changes include:
|
||||
// * Adding a new field to the PJRT_Api or argument structs
|
||||
// * Renaming a method or argument (doesn't affect ABI)
|
||||
#define PJRT_API_MINOR 80
|
||||
#define PJRT_API_MINOR 81
|
||||
|
||||
// The plugin should set the major_version and minor_version of
|
||||
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
|
||||
|
|
|
|||
|
|
@ -26,8 +26,9 @@ extern "C" {
|
|||
#endif
|
||||
|
||||
// This extension provides capabilities around custom on-device memory layouts
|
||||
// for PJRT_Buffers. The extension is both optional and experimental, meaning
|
||||
// ABI-breaking and other incompatible changes may be introduced at any time.
|
||||
// for PJRT_Buffers and PJRT_Executables. The extension is both optional and
|
||||
// experimental, meaning ABI-breaking and other incompatible changes may be
|
||||
// introduced at any time.
|
||||
//
|
||||
// If this extension is provided, JAX and possibly other frameworks will assume
|
||||
// that the compiler MLIR input can contain "mhlo.layout_mode" attributes on
|
||||
|
|
@ -36,7 +37,7 @@ extern "C" {
|
|||
// https://github.com/openxla/xla/blob/main/xla/pjrt/layout_mode.h for more
|
||||
// details.
|
||||
|
||||
#define PJRT_API_LAYOUTS_EXTENSION_VERSION 2
|
||||
#define PJRT_API_LAYOUTS_EXTENSION_VERSION 3
|
||||
|
||||
// -------------------------------- Data types ---------------------------------
|
||||
|
||||
|
|
@ -124,6 +125,23 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_PJRT_Topology_GetDefaultLayout_Args,
|
|||
typedef PJRT_Error* PJRT_Layouts_PJRT_Topology_GetDefaultLayout(
|
||||
PJRT_Layouts_PJRT_Topology_GetDefaultLayout_Args* args);
|
||||
|
||||
// Returns output layouts for an executable.
|
||||
struct PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args {
|
||||
size_t struct_size;
|
||||
PJRT_Extension_Base* extension_start;
|
||||
PJRT_Executable* executable;
|
||||
size_t num_outputs; // out
|
||||
// Layout data is owned by and has the lifetime of `executable`.
|
||||
// Has length `num_outputs`.
|
||||
PJRT_Layouts_MemoryLayout** layouts; // out
|
||||
};
|
||||
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args,
|
||||
layouts);
|
||||
|
||||
// Returns a list of layouts for executable outputs. Each output has a layout.
|
||||
typedef PJRT_Error* PJRT_Layouts_PJRT_Executable_GetOutputLayouts(
|
||||
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args* args);
|
||||
|
||||
// --------------------------- Extension entrypoint ----------------------------
|
||||
|
||||
typedef struct PJRT_Layouts_Extension {
|
||||
|
|
@ -136,9 +154,11 @@ typedef struct PJRT_Layouts_Extension {
|
|||
PJRT_Layouts_PJRT_Buffer_MemoryLayout* PJRT_Layouts_PJRT_Buffer_MemoryLayout;
|
||||
PJRT_Layouts_PJRT_Topology_GetDefaultLayout*
|
||||
PJRT_Layouts_PJRT_Topology_GetDefaultLayout;
|
||||
PJRT_Layouts_PJRT_Executable_GetOutputLayouts*
|
||||
PJRT_Layouts_PJRT_Executable_GetOutputLayouts;
|
||||
} PJRT_Layouts_Extension;
|
||||
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_Extension,
|
||||
PJRT_Layouts_PJRT_Topology_GetDefaultLayout);
|
||||
PJRT_Layouts_PJRT_Executable_GetOutputLayouts);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -219,6 +219,33 @@ static absl::Status EnsureExecutableOutputDimensionsPopulated(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
static absl::Status PopulateExecutableOutputLayouts(
|
||||
PJRT_Executable* executable) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<std::shared_ptr<const xla::PjRtLayout>> cpp_out_layouts,
|
||||
executable->get()->GetOutputLayouts());
|
||||
executable->out_layouts.reserve(cpp_out_layouts.size());
|
||||
executable->out_layouts_pointers.reserve(cpp_out_layouts.size());
|
||||
for (std::shared_ptr<const xla::PjRtLayout>& layout : cpp_out_layouts) {
|
||||
executable->out_layouts.push_back(
|
||||
PJRT_Layouts_MemoryLayout{std::move(layout)});
|
||||
}
|
||||
for (PJRT_Layouts_MemoryLayout& layout : executable->out_layouts) {
|
||||
executable->out_layouts_pointers.push_back(&layout);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
static absl::Status EnsureExecutableOutputLayoutsPopulated(
|
||||
PJRT_Executable* executable) {
|
||||
absl::MutexLock lock(executable->mutex);
|
||||
if (!executable->out_layouts_ran) {
|
||||
TF_RETURN_IF_ERROR(PopulateExecutableOutputLayouts(executable));
|
||||
executable->out_layouts_ran = true;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
static absl::Status PopulateExecutableOutputMemoryKinds(
|
||||
PJRT_Executable* executable) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
|
|
@ -2692,6 +2719,19 @@ PJRT_Error* PJRT_Layouts_PJRT_Topology_GetDefaultLayout(
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
PJRT_Error* PJRT_Layouts_PJRT_Executable_GetOutputLayouts(
|
||||
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args* args) {
|
||||
PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
|
||||
"PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args",
|
||||
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE,
|
||||
args->struct_size));
|
||||
PJRT_RETURN_IF_ERROR(
|
||||
EnsureExecutableOutputLayoutsPopulated(args->executable));
|
||||
args->num_outputs = args->executable->out_layouts_pointers.size();
|
||||
args->layouts = args->executable->out_layouts_pointers.data();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static std::vector<PJRT_NamedValue> PopulatePjrtAttributes(
|
||||
const absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>&
|
||||
attributes) {
|
||||
|
|
@ -3106,6 +3146,8 @@ PJRT_Layouts_Extension CreateLayoutsExtension(PJRT_Extension_Base* next) {
|
|||
pjrt::PJRT_Layouts_PJRT_Buffer_MemoryLayout,
|
||||
/*PJRT_Layouts_PJRT_Topology_GetDefaultLayout=*/
|
||||
&PJRT_Layouts_PJRT_Topology_GetDefaultLayout,
|
||||
/*PJRT_Layouts_PJRT_Executable_GetOutputLayouts=*/
|
||||
&PJRT_Layouts_PJRT_Executable_GetOutputLayouts,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -158,6 +158,10 @@ struct PJRT_Executable {
|
|||
std::vector<int64_t> out_dimensions;
|
||||
std::vector<size_t> out_dimension_sizes;
|
||||
|
||||
bool out_layouts_ran ABSL_GUARDED_BY(mutex) = false;
|
||||
std::vector<PJRT_Layouts_MemoryLayout> out_layouts;
|
||||
std::vector<PJRT_Layouts_MemoryLayout*> out_layouts_pointers;
|
||||
|
||||
explicit PJRT_Executable(std::shared_ptr<xla::PjRtExecutable> executable);
|
||||
explicit PJRT_Executable(xla::PjRtExecutable* executable);
|
||||
|
||||
|
|
|
|||
|
|
@ -1725,6 +1725,63 @@ PjRtCApiExecutable::GetOutputDimensions() const {
|
|||
return std::vector<std::vector<DimensionVector>>{std::move(out)};
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
|
||||
PjRtCApiExecutable::GetOutputLayouts() const {
|
||||
const PJRT_Api* c_api = pjrt_c_api();
|
||||
if (c_api->pjrt_api_version.major_version == 0 &&
|
||||
c_api->pjrt_api_version.minor_version < 81) {
|
||||
// If the PJRT C API version is too old, fall back to the default
|
||||
// implementation.
|
||||
return this->PjRtExecutable::GetOutputLayouts();
|
||||
}
|
||||
PJRT_Layouts_Extension* extension =
|
||||
pjrt::FindExtension<PJRT_Layouts_Extension>(
|
||||
c_api, PJRT_Extension_Type::PJRT_Extension_Type_Layouts);
|
||||
if (extension == nullptr ||
|
||||
extension->PJRT_Layouts_MemoryLayout_Serialize == nullptr ||
|
||||
extension->PJRT_Layouts_PJRT_Executable_GetOutputLayouts == nullptr) {
|
||||
// If we can't find PJRT_Layouts_PJRT_Executable_GetOutputLayouts support,
|
||||
// fall back to the default implementation.
|
||||
return this->PjRtExecutable::GetOutputLayouts();
|
||||
}
|
||||
|
||||
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args args;
|
||||
args.struct_size =
|
||||
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE;
|
||||
args.extension_start = nullptr;
|
||||
args.executable = c_executable();
|
||||
RETURN_STATUS_IF_PJRT_ERROR(
|
||||
extension->PJRT_Layouts_PJRT_Executable_GetOutputLayouts(&args), c_api);
|
||||
|
||||
std::vector<std::shared_ptr<const PjRtLayout>> layouts;
|
||||
layouts.reserve(args.num_outputs);
|
||||
for (int i = 0; i < args.num_outputs; ++i) {
|
||||
// TODO(b/343274093): returns a PjRtLayout that wraps a C API layout
|
||||
// directly instead of de/serializing into an xla::Layout.
|
||||
PJRT_Layouts_MemoryLayout_Serialize_Args serialize_args;
|
||||
serialize_args.struct_size =
|
||||
PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE;
|
||||
serialize_args.extension_start = nullptr;
|
||||
serialize_args.layout = args.layouts[i];
|
||||
pjrt::LogFatalIfPjrtError(
|
||||
extension->PJRT_Layouts_MemoryLayout_Serialize(&serialize_args), c_api);
|
||||
|
||||
// Clean up `PJRT_Layouts_SerializedLayout`.
|
||||
absl::Cleanup cleanup = [&serialize_args] {
|
||||
serialize_args.serialized_layout_deleter(
|
||||
serialize_args.serialized_layout);
|
||||
};
|
||||
|
||||
std::string serialized_layout(serialize_args.serialized_bytes,
|
||||
serialize_args.serialized_bytes_size);
|
||||
absl::StatusOr<std::shared_ptr<const PjRtLayout>> pjrt_layout =
|
||||
PjRtLayout::Deserialize(serialized_layout);
|
||||
TF_CHECK_OK(pjrt_layout.status());
|
||||
layouts.push_back(*std::move(pjrt_layout));
|
||||
}
|
||||
return layouts;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
|
||||
PjRtCApiExecutable::GetOutputMemoryKinds() const {
|
||||
PJRT_Executable_OutputMemoryKinds_Args args;
|
||||
|
|
|
|||
|
|
@ -590,6 +590,9 @@ class PjRtCApiExecutable : public PjRtExecutable {
|
|||
absl::StatusOr<std::vector<std::vector<DimensionVector>>>
|
||||
GetOutputDimensions() const override;
|
||||
|
||||
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
|
||||
GetOutputLayouts() const override;
|
||||
|
||||
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
|
||||
GetOutputMemoryKinds() const override;
|
||||
|
||||
|
|
@ -671,6 +674,11 @@ class PjRtCApiLoadedExecutable : public PjRtLoadedExecutable {
|
|||
return executable_->GetOutputDimensions();
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
|
||||
GetOutputLayouts() const override {
|
||||
return executable_->GetOutputLayouts();
|
||||
}
|
||||
|
||||
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
|
||||
GetOutputMemoryKinds() const override {
|
||||
return executable_->GetOutputMemoryKinds();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user