[PjRt-IFRT] Temporary workaround for output layout handling

PjRt-IFRT directly or indirectly fetched the optimized HLO to obtain the
output layout mode and output layouts. This appears to have introduced a
regression in some jobs that use the PJRT C API and whose serialized HLO is
too large (> 2 GiB).

As a workaround, PjRt-IFRT now gracefully handles errors in output layout
mode and layout discovery, falling back to the concrete layouts obtained
directly from the output `PjRtBuffer`s. This should give the same behavior
before and after the default layout handling change.

Further changes will follow to discover default layout modes and layouts
without going through `PjRtLoadedExecutable::GetHloModules()`.
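
In outline, the workaround converts a failed layout-discovery status into an
unset optional instead of propagating the error. A minimal sketch of that
conversion, with a generic `Layout` type standing in for `xla::PjRtLayout`
(illustrative only, not the actual IFRT code):

#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/status/statusor.h"

// Sketch only: if compile-time layout discovery failed (e.g. because the
// serialized optimized HLO is too large to fetch), leave the layouts unset
// so that execution can fall back to the output buffers' concrete layouts.
template <typename Layout>
std::optional<std::vector<std::shared_ptr<const Layout>>>
TryDiscoverOutputLayouts(
    absl::StatusOr<std::vector<std::shared_ptr<const Layout>>> discovered) {
  if (!discovered.ok()) {
    // Swallow the error; Execute() will use the output buffers' layouts.
    return std::nullopt;
  }
  return *std::move(discovered);
}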

PiperOrigin-RevId: 820785277
Hyeontaek Lim, 2025-10-17 12:14:54 -07:00; committed by TensorFlower Gardener
parent b07145966f
commit 05101b9755
2 changed files with 111 additions and 24 deletions


@@ -296,17 +296,34 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
   TF_ASSIGN_OR_RETURN(
       auto result_memory_kinds,
       GetFirstModuleOutputMemoryKinds(pjrt_loaded_executable.get()));
-  TF_ASSIGN_OR_RETURN(auto hlo_modules,
-                      pjrt_loaded_executable->GetHloModules());
-  if (hlo_modules.empty()) {
-    return FailedPrecondition("Requires at least one HloModule.");
+  // Obtaining output layout modes and output layouts directly from
+  // `PjRtLoadedExecutable` may fail because the current PjRt implementations
+  // often fetch and serialize the optimized HLO. For now, we gracefully
+  // handle it by omitting output layouts at creation time and using output
+  // `PjRtBuffer`'s concrete layouts.
+  // TODO(hyeontaek): Add a way to obtain output layout modes and
+  // `PjRtLoadedExecutable::GetOutputLayouts()` without causing the optimized
+  // HLO to be serialized and fetched.
+  std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+      output_layouts;
+  absl::StatusOr<std::vector<std::shared_ptr<HloModule>>> hlo_modules =
+      pjrt_loaded_executable->GetHloModules();
+  if (hlo_modules.ok()) {
+    if (hlo_modules->empty()) {
+      return FailedPrecondition("Requires at least one HloModule.");
+    }
+    absl::StatusOr<std::vector<xla::LayoutMode>> output_layout_modes =
+        GetLayoutModes(*hlo_modules->front(), "out_layout_modes",
+                       result_element_types.size());
+    if (output_layout_modes.ok()) {
+      absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+          first_module_output_layouts = GetFirstModuleOutputLayouts(
+              pjrt_loaded_executable.get(), *output_layout_modes);
+      if (first_module_output_layouts.ok()) {
+        output_layouts = *std::move(first_module_output_layouts);
+      }
+    }
   }
-  TF_ASSIGN_OR_RETURN(std::vector<xla::LayoutMode> output_layout_modes,
-                      GetLayoutModes(*hlo_modules.front(), "out_layout_modes",
-                                     result_element_types.size()));
-  TF_ASSIGN_OR_RETURN(auto output_layouts,
-                      GetFirstModuleOutputLayouts(pjrt_loaded_executable.get(),
-                                                  output_layout_modes));
   return CreateInternal(client, std::move(pjrt_loaded_executable),
                         result_element_types, result_dimensions,
                         /*result_hlo_sharding=*/std::nullopt,
@@ -352,8 +369,8 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
   // will use the MLIR as scratch space, or possibly even deallocate it.
   TF_ASSIGN_OR_RETURN(const std::vector<xla::Shape> result_shapes,
                       ResultShapesOfModule(module));
-  TF_ASSIGN_OR_RETURN(const std::vector<xla::LayoutMode> output_layout_modes,
-                      GetOutputLayoutModes(module));
+  absl::StatusOr<std::vector<xla::LayoutMode>> output_layout_modes =
+      GetOutputLayoutModes(module);
   TF_ASSIGN_OR_RETURN(auto pjrt_loaded_executable,
                       client->pjrt_client()->CompileAndLoad(
@@ -372,9 +389,24 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
   TF_ASSIGN_OR_RETURN(
       auto result_memory_kinds,
       GetFirstModuleOutputMemoryKinds(pjrt_loaded_executable.get()));
-  TF_ASSIGN_OR_RETURN(auto output_layouts,
-                      GetFirstModuleOutputLayouts(
-                          pjrt_loaded_executable.get(), output_layout_modes));
+  // Obtaining output layout modes and output layouts directly from
+  // `PjRtLoadedExecutable` may fail because the current PjRt
+  // implementations often fetch and serialize the optimized HLO. For now, we
+  // gracefully handle it by omitting output layouts at creation time and
+  // using output `PjRtBuffer`'s concrete layouts.
+  // TODO(hyeontaek): Add a way to obtain output layout modes and
+  // `PjRtLoadedExecutable::GetOutputLayouts()` without causing the optimized
+  // HLO to be serialized and fetched.
+  std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+      output_layouts;
+  if (output_layout_modes.ok()) {
+    absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+        first_module_output_layouts = GetFirstModuleOutputLayouts(
+            pjrt_loaded_executable.get(), *output_layout_modes);
+    if (first_module_output_layouts.ok()) {
+      output_layouts = *std::move(first_module_output_layouts);
+    }
+  }
   return CreateInternal(client, std::move(pjrt_loaded_executable),
                         result_element_types, result_dimensions,
                         /*result_hlo_sharding=*/std::nullopt,
@@ -405,9 +437,24 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
   TF_ASSIGN_OR_RETURN(
       auto result_memory_kinds,
       GetFirstModuleOutputMemoryKinds(pjrt_loaded_executable.get()));
-  TF_ASSIGN_OR_RETURN(auto output_layouts,
-                      GetFirstModuleOutputLayouts(
-                          pjrt_loaded_executable.get(), output_layout_modes));
+  // Obtaining output layout modes and output layouts directly from
+  // `PjRtLoadedExecutable` may fail because the current PjRt
+  // implementations often fetch and serialize the optimized HLO. For now, we
+  // gracefully handle it by omitting output layouts at creation time and
+  // using output `PjRtBuffer`'s concrete layouts.
+  // TODO(hyeontaek): Add a way to obtain output layout modes and
+  // `PjRtLoadedExecutable::GetOutputLayouts()` without causing the optimized
+  // HLO to be serialized and fetched.
+  std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+      output_layouts;
+  if (output_layout_modes.ok()) {
+    absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+        first_module_output_layouts = GetFirstModuleOutputLayouts(
+            pjrt_loaded_executable.get(), *output_layout_modes);
+    if (first_module_output_layouts.ok()) {
+      output_layouts = *std::move(first_module_output_layouts);
+    }
+  }
   return CreateInternal(
       client, std::move(pjrt_loaded_executable),
       shape_partial_info.element_types, shape_partial_info.dimensions,
@@ -423,7 +470,8 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::CreateInternal(
     absl::Span<const xla::DimensionVector> result_dimensions,
     const std::optional<xla::HloSharding>& result_hlo_sharding,
     const std::optional<std::vector<absl::string_view>>& result_memory_kinds,
-    const std::vector<std::shared_ptr<const xla::PjRtLayout>>& output_layouts,
+    const std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>&
+        output_layouts,
     std::vector<tsl::RCReference<LoadedHostCallback>> loaded_host_callbacks,
     DeviceListRef executable_devices) {
   // For jit(pmap(...)), the device assignment (passed as `executable_devices`)
@@ -596,7 +644,8 @@ PjRtLoadedExecutable::PjRtLoadedExecutable(
         host_send_recv_callbacks,
     std::vector<DType> output_dtypes, std::vector<Shape> output_shapes,
     std::vector<ShardingRef> output_shardings,
-    std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts)
+    std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+        output_layouts)
     : client_(client),
       pjrt_loaded_executable_(std::move(pjrt_loaded_executable)),
       devices_(std::move(devices)),
@@ -812,6 +861,41 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
   // memory_kind shares the same Sharding object.
   absl::flat_hash_map<MemoryKind, ShardingRef> single_device_shardings;
 
+  std::vector<std::shared_ptr<const xla::PjRtLayout>> layouts;
+  layouts.reserve(num_outputs);
+  if (output_layouts_.has_value()) {
+    // TODO(hyeontaek): Once we can get `output_layouts_` reliably, only keep
+    // this path.
+    layouts = *output_layouts_;
+  } else if (!pjrt_outputs.empty()) {
+    for (int i = 0; i < num_outputs; ++i) {
+      auto layout = output_dtypes_[i].kind() == xla::ifrt::DType::kToken
+                        ? std::make_shared<xla::PjRtLayout>(xla::Layout())
+                        : pjrt_outputs.front()[i]->layout();
+      layouts.push_back(std::move(layout));
+    }
+  } else {
+    auto maybe_layouts = GetOutputLayouts();
+    if (absl::IsUnimplemented(maybe_layouts.status())) {
+      for (int i = 0; i < num_outputs; ++i) {
+        std::shared_ptr<const xla::PjRtLayout> layout;
+        if (output_dtypes_[i].kind() == xla::ifrt::DType::kToken) {
+          layout = std::make_shared<xla::PjRtLayout>(xla::Layout());
+        } else {
+          TF_ASSIGN_OR_RETURN(layout,
+                              client_->GetDefaultPjRtLayout(
+                                  output_dtypes_[i], output_shapes_[i].dims(),
+                                  devices_->devices().front(),
+                                  output_shardings_[i]->memory_kind()));
+        }
+        layouts.push_back(std::move(layout));
+      }
+    } else {
+      TF_RETURN_IF_ERROR(maybe_layouts.status());
+      layouts = *std::move(maybe_layouts);
+    }
+  }
+
   for (int i = 0; i < num_outputs; ++i) {
     PjRtArray::PjRtBuffers buffers;
     buffers.reserve(num_computations);
@@ -852,7 +936,7 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
     }
     outputs.push_back(*PjRtArray::Create(
         client_, output_dtypes_[i], output_shapes_[i], *std::move(sharding),
-        std::move(buffers), output_layouts_[i]));
+        std::move(buffers), std::move(layouts[i])));
   }
 
   ExecuteResult result;

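The `Execute()` hunk above consults layout sources in a fixed priority order:
cached compile-time layouts, then the concrete layouts of the output buffers,
then `GetOutputLayouts()`, treating an `Unimplemented` status as a request for
client-default layouts. A self-contained sketch of that ordering, using a
stand-in `Layout` type rather than the real IFRT types:

#include <memory>
#include <optional>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// `Layout` stands in for xla::PjRtLayout; the parameters mirror the sources
// consulted by Execute() above, in priority order.
struct Layout {};
using LayoutPtr = std::shared_ptr<const Layout>;

absl::StatusOr<LayoutPtr> PickOutputLayout(
    const std::optional<std::vector<LayoutPtr>>& cached,  // output_layouts_
    const std::vector<LayoutPtr>& buffer_layouts,  // from output PjRtBuffers
    const absl::StatusOr<std::vector<LayoutPtr>>& queried,  // GetOutputLayouts
    const LayoutPtr& default_layout, int i) {
  if (cached.has_value()) return (*cached)[i];  // 1. compile-time discovery
  if (!buffer_layouts.empty()) return buffer_layouts[i];  // 2. concrete layout
  if (absl::IsUnimplemented(queried.status())) {
    return default_layout;  // 3. client-default layout
  }
  if (!queried.ok()) return queried.status();  // propagate other errors
  return (*queried)[i];  // 4. layouts reported by the query
}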

@@ -339,7 +339,8 @@ class PjRtLoadedExecutable final
       absl::Span<const xla::DimensionVector> result_dimensions,
       const std::optional<xla::HloSharding>& result_hlo_sharding,
       const std::optional<std::vector<absl::string_view>>& result_memory_kinds,
-      const std::vector<std::shared_ptr<const xla::PjRtLayout>>& output_layouts,
+      const std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>&
+          output_layouts,
       std::vector<tsl::RCReference<LoadedHostCallback>> loaded_host_callbacks,
       DeviceListRef executable_devices);
@@ -353,7 +354,8 @@ class PjRtLoadedExecutable final
           host_send_recv_callbacks,
       std::vector<DType> output_dtypes, std::vector<Shape> output_shapes,
       std::vector<ShardingRef> output_shardings,
-      std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts);
+      std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+          output_layouts);
 
   PjRtClient* client_;
   std::shared_ptr<xla::PjRtLoadedExecutable> pjrt_loaded_executable_;
@@ -372,7 +374,8 @@ class PjRtLoadedExecutable final
   std::vector<DType> output_dtypes_;
   std::vector<Shape> output_shapes_;
   std::vector<ShardingRef> output_shardings_;
-  std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts_;
+  std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+      output_layouts_;
 
   const xla::ifrt::UserContextRef user_context_;
 };
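
A note on the member type change above: `std::optional` distinguishes "layout
discovery failed, so decide at execution time" (`std::nullopt`) from
"discovery succeeded and the executable has zero outputs" (an engaged, empty
vector), which a bare `std::vector` could not express. A minimal illustration:

#include <cassert>
#include <optional>
#include <vector>

int main() {
  // Discovery failed: no value at all; Execute() must fall back.
  std::optional<std::vector<int>> layouts;
  assert(!layouts.has_value());

  // Discovery succeeded for an executable with zero outputs: engaged but
  // empty, so Execute() can use it as-is.
  layouts.emplace();
  assert(layouts.has_value() && layouts->empty());
  return 0;
}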