mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
parent
ffc0e052de
commit
16064a6c08
|
|
@ -26,9 +26,8 @@ extern "C" {
|
|||
#endif
|
||||
|
||||
// This extension provides capabilities around custom on-device memory layouts
|
||||
// for PJRT_Buffers and PJRT_Executables. The extension is both optional and
|
||||
// experimental, meaning ABI-breaking and other incompatible changes may be
|
||||
// introduced at any time.
|
||||
// for PJRT_Buffers. The extension is both optional and experimental, meaning
|
||||
// ABI-breaking and other incompatible changes may be introduced at any time.
|
||||
//
|
||||
// If this extension is provided, JAX and possibly other frameworks will assume
|
||||
// that the compiler MLIR input can contain "mhlo.layout_mode" attributes on
|
||||
|
|
@ -37,7 +36,7 @@ extern "C" {
|
|||
// https://github.com/openxla/xla/blob/main/xla/pjrt/layout_mode.h for more
|
||||
// details.
|
||||
|
||||
#define PJRT_API_LAYOUTS_EXTENSION_VERSION 3
|
||||
#define PJRT_API_LAYOUTS_EXTENSION_VERSION 2
|
||||
|
||||
// -------------------------------- Data types ---------------------------------
|
||||
|
||||
|
|
@ -125,23 +124,6 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_PJRT_Topology_GetDefaultLayout_Args,
|
|||
typedef PJRT_Error* PJRT_Layouts_PJRT_Topology_GetDefaultLayout(
|
||||
PJRT_Layouts_PJRT_Topology_GetDefaultLayout_Args* args);
|
||||
|
||||
// Returns output layouts for an executable.
|
||||
struct PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args {
|
||||
size_t struct_size;
|
||||
PJRT_Extension_Base* extension_start;
|
||||
PJRT_Executable* executable;
|
||||
size_t num_outputs; // out
|
||||
// Layout data is owned by and has the lifetime of `executable`.
|
||||
// Has length `num_outputs`.
|
||||
PJRT_Layouts_MemoryLayout** layouts; // out
|
||||
};
|
||||
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args,
|
||||
layouts);
|
||||
|
||||
// Returns a list of layouts for executable outputs. Each output has a layout.
|
||||
typedef PJRT_Error* PJRT_Layouts_PJRT_Executable_GetOutputLayouts(
|
||||
PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args* args);
|
||||
|
||||
// --------------------------- Extension entrypoint ----------------------------
|
||||
|
||||
typedef struct PJRT_Layouts_Extension {
|
||||
|
|
@ -154,11 +136,9 @@ typedef struct PJRT_Layouts_Extension {
|
|||
PJRT_Layouts_PJRT_Buffer_MemoryLayout* PJRT_Layouts_PJRT_Buffer_MemoryLayout;
|
||||
PJRT_Layouts_PJRT_Topology_GetDefaultLayout*
|
||||
PJRT_Layouts_PJRT_Topology_GetDefaultLayout;
|
||||
PJRT_Layouts_PJRT_Executable_GetOutputLayouts*
|
||||
PJRT_Layouts_PJRT_Executable_GetOutputLayouts;
|
||||
} PJRT_Layouts_Extension;
|
||||
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Layouts_Extension,
|
||||
PJRT_Layouts_PJRT_Executable_GetOutputLayouts);
|
||||
PJRT_Layouts_PJRT_Topology_GetDefaultLayout);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -219,33 +219,6 @@ static absl::Status EnsureExecutableOutputDimensionsPopulated(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Fetches the output layouts from the wrapped executable and caches them on
// `executable`: each returned layout is appended to `out_layouts` wrapped in a
// PJRT_Layouts_MemoryLayout, and the address of every `out_layouts` element is
// then appended to `out_layouts_pointers`. If GetOutputLayouts() fails, its
// error is returned before either vector is touched.
static absl::Status PopulateExecutableOutputLayouts(
    PJRT_Executable* executable) {
  TF_ASSIGN_OR_RETURN(
      std::vector<std::shared_ptr<const xla::PjRtLayout>> fetched_layouts,
      executable->get()->GetOutputLayouts());

  const size_t fetched_count = fetched_layouts.size();
  auto& cached_layouts = executable->out_layouts;
  auto& cached_pointers = executable->out_layouts_pointers;
  cached_layouts.reserve(fetched_count);
  cached_pointers.reserve(fetched_count);

  for (size_t i = 0; i < fetched_count; ++i) {
    cached_layouts.push_back(
        PJRT_Layouts_MemoryLayout{std::move(fetched_layouts[i])});
  }
  // Addresses are taken in a second pass, once `out_layouts` has stopped
  // growing, so every stored pointer refers to the element's final location.
  for (size_t i = 0; i < cached_layouts.size(); ++i) {
    cached_pointers.push_back(&cached_layouts[i]);
  }
  return absl::OkStatus();
}
|
||||
|
||||
// Populates the cached output layouts of `executable` the first time it is
// called, holding `executable->mutex` for the whole check-and-populate step.
// `out_layouts_ran` is set only after a successful population, so a failed
// attempt is retried on the next call.
static absl::Status EnsureExecutableOutputLayoutsPopulated(
    PJRT_Executable* executable) {
  absl::MutexLock lock(executable->mutex);
  if (executable->out_layouts_ran) {
    // Already cached by an earlier call.
    return absl::OkStatus();
  }
  TF_RETURN_IF_ERROR(PopulateExecutableOutputLayouts(executable));
  executable->out_layouts_ran = true;
  return absl::OkStatus();
}
|
||||
|
||||
static absl::Status PopulateExecutableOutputMemoryKinds(
|
||||
PJRT_Executable* executable) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
|
|
@ -2719,19 +2692,6 @@ PJRT_Error* PJRT_Layouts_PJRT_Topology_GetDefaultLayout(
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// C API entry point: reports the output layouts of `args->executable`.
//
// Validates `args->struct_size`, makes sure the executable's layout cache is
// populated, then points `args->layouts` at the cached pointer array and sets
// `args->num_outputs` to its length. Returns nullptr on success, or a
// PJRT_Error if validation or population fails.
PJRT_Error* PJRT_Layouts_PJRT_Executable_GetOutputLayouts(
    PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args* args) {
  PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
      "PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args",
      PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE,
      args->struct_size));

  PJRT_Executable* const target = args->executable;
  PJRT_RETURN_IF_ERROR(EnsureExecutableOutputLayoutsPopulated(target));

  // The returned array is the executable's own cache, not a copy.
  auto& layout_pointers = target->out_layouts_pointers;
  args->layouts = layout_pointers.data();
  args->num_outputs = layout_pointers.size();
  return nullptr;
}
|
||||
|
||||
static std::vector<PJRT_NamedValue> PopulatePjrtAttributes(
|
||||
const absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute>&
|
||||
attributes) {
|
||||
|
|
@ -3146,8 +3106,6 @@ PJRT_Layouts_Extension CreateLayoutsExtension(PJRT_Extension_Base* next) {
|
|||
pjrt::PJRT_Layouts_PJRT_Buffer_MemoryLayout,
|
||||
/*PJRT_Layouts_PJRT_Topology_GetDefaultLayout=*/
|
||||
&PJRT_Layouts_PJRT_Topology_GetDefaultLayout,
|
||||
/*PJRT_Layouts_PJRT_Executable_GetOutputLayouts=*/
|
||||
&PJRT_Layouts_PJRT_Executable_GetOutputLayouts,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -158,10 +158,6 @@ struct PJRT_Executable {
|
|||
std::vector<int64_t> out_dimensions;
|
||||
std::vector<size_t> out_dimension_sizes;
|
||||
|
||||
bool out_layouts_ran ABSL_GUARDED_BY(mutex) = false;
|
||||
std::vector<PJRT_Layouts_MemoryLayout> out_layouts;
|
||||
std::vector<PJRT_Layouts_MemoryLayout*> out_layouts_pointers;
|
||||
|
||||
explicit PJRT_Executable(std::shared_ptr<xla::PjRtExecutable> executable);
|
||||
explicit PJRT_Executable(xla::PjRtExecutable* executable);
|
||||
|
||||
|
|
|
|||
|
|
@ -1725,57 +1725,6 @@ PjRtCApiExecutable::GetOutputDimensions() const {
|
|||
return std::vector<std::vector<DimensionVector>>{std::move(out)};
|
||||
}
|
||||
|
||||
// Returns the output layouts of this executable as reported by the plugin's
// Layouts extension.
//
// If the plugin has no Layouts extension, or the extension lacks either
// PJRT_Layouts_MemoryLayout_Serialize or
// PJRT_Layouts_PJRT_Executable_GetOutputLayouts, this falls back to the base
// PjRtExecutable::GetOutputLayouts() implementation.
//
// Each C API layout is serialized by the plugin and deserialized into a
// PjRtLayout. A failure at any step (fetching the layouts, serializing one, or
// deserializing one) is returned to the caller as an error status rather than
// aborting the process.
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
PjRtCApiExecutable::GetOutputLayouts() const {
  const PJRT_Api* c_api = pjrt_c_api();
  PJRT_Layouts_Extension* extension =
      pjrt::FindExtension<PJRT_Layouts_Extension>(
          c_api, PJRT_Extension_Type::PJRT_Extension_Type_Layouts);
  if (extension == nullptr ||
      extension->PJRT_Layouts_MemoryLayout_Serialize == nullptr ||
      extension->PJRT_Layouts_PJRT_Executable_GetOutputLayouts == nullptr) {
    // If we can't find PJRT_Layouts_PJRT_Executable_GetOutputLayouts support,
    // fall back to the default implementation.
    return this->PjRtExecutable::GetOutputLayouts();
  }

  PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args args;
  args.struct_size =
      PJRT_Layouts_PJRT_Executable_GetOutputLayouts_Args_STRUCT_SIZE;
  args.extension_start = nullptr;
  args.executable = c_executable();
  // Give the out-fields defined values before handing the struct to the
  // plugin.
  args.num_outputs = 0;
  args.layouts = nullptr;
  RETURN_STATUS_IF_PJRT_ERROR(
      extension->PJRT_Layouts_PJRT_Executable_GetOutputLayouts(&args), c_api);

  std::vector<std::shared_ptr<const PjRtLayout>> layouts;
  layouts.reserve(args.num_outputs);
  // `num_outputs` is a size_t, so the index is a size_t too.
  for (size_t i = 0; i < args.num_outputs; ++i) {
    // TODO(b/343274093): returns a PjRtLayout that wraps a C API layout
    // directly instead of de/serializing into an xla::Layout.
    PJRT_Layouts_MemoryLayout_Serialize_Args serialize_args;
    serialize_args.struct_size =
        PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE;
    serialize_args.extension_start = nullptr;
    serialize_args.layout = args.layouts[i];
    // This function already returns a StatusOr, so a serialization failure is
    // reported to the caller instead of crashing. The cleanup below is only
    // armed after a successful call, so the deleter never runs on unset
    // fields.
    RETURN_STATUS_IF_PJRT_ERROR(
        extension->PJRT_Layouts_MemoryLayout_Serialize(&serialize_args), c_api);

    // Clean up `PJRT_Layouts_SerializedLayout`.
    absl::Cleanup cleanup = [&serialize_args] {
      serialize_args.serialized_layout_deleter(
          serialize_args.serialized_layout);
    };

    std::string serialized_layout(serialize_args.serialized_bytes,
                                  serialize_args.serialized_bytes_size);
    absl::StatusOr<std::shared_ptr<const PjRtLayout>> pjrt_layout =
        PjRtLayout::Deserialize(serialized_layout);
    if (!pjrt_layout.ok()) {
      return pjrt_layout.status();
    }
    layouts.push_back(*std::move(pjrt_layout));
  }
  return layouts;
}
|
||||
|
||||
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
|
||||
PjRtCApiExecutable::GetOutputMemoryKinds() const {
|
||||
PJRT_Executable_OutputMemoryKinds_Args args;
|
||||
|
|
|
|||
|
|
@ -590,9 +590,6 @@ class PjRtCApiExecutable : public PjRtExecutable {
|
|||
absl::StatusOr<std::vector<std::vector<DimensionVector>>>
|
||||
GetOutputDimensions() const override;
|
||||
|
||||
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
|
||||
GetOutputLayouts() const override;
|
||||
|
||||
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
|
||||
GetOutputMemoryKinds() const override;
|
||||
|
||||
|
|
@ -674,11 +671,6 @@ class PjRtCApiLoadedExecutable : public PjRtLoadedExecutable {
|
|||
return executable_->GetOutputDimensions();
|
||||
}
|
||||
|
||||
// Forwards to the wrapped executable's GetOutputLayouts() and returns its
// result unchanged.
absl::StatusOr<std::vector<std::shared_ptr<const PjRtLayout>>>
GetOutputLayouts() const override {
  auto output_layouts = executable_->GetOutputLayouts();
  return output_layouts;
}
|
||||
|
||||
absl::StatusOr<std::vector<std::vector<absl::string_view>>>
|
||||
GetOutputMemoryKinds() const override {
|
||||
return executable_->GetOutputMemoryKinds();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user