[PjRt-IFRT] Temporary workaround for output layout handling
PjRt-IFRT directly or indirectly fetched the optimized HLO to get the output layout modes and output layouts. This seems to have introduced a regression in some jobs that use the PJRT C API and have a serialized HLO that is too large (> 2 GiB). As a workaround, PjRt-IFRT now gracefully handles output layout mode and layout discovery errors and falls back to the concrete layouts obtained directly from the output `PjRtBuffer`s, which should give the same behavior as before the default layout handling change. Further changes will follow to discover default layout modes and layouts without going through `PjRtLoadedExecutable::GetHloModules()`.

PiperOrigin-RevId: 820785277
This commit is contained in:
parent b07145966f
commit 05101b9755
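The pattern at the heart of this change is best-effort discovery: any error on the layout-discovery path downgrades the cached layouts to `std::nullopt` instead of failing executable creation, and array construction later falls back to the concrete layouts on the output `PjRtBuffer`s. Below is a minimal self-contained sketch of that pattern, assuming Abseil's status types; `Layout`, `TryDiscoverLayouts`, and `DiscoverLayoutsBestEffort` are hypothetical stand-ins for illustration, not XLA APIs.

// Minimal sketch of the best-effort discovery pattern (assumed shape; not
// the actual XLA code). Requires Abseil (absl/status).
#include <iostream>
#include <optional>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical stand-in for xla::PjRtLayout.
struct Layout {
  int id = 0;
};

// Hypothetical discovery step that can fail, e.g. when fetching the
// optimized HLO through the PJRT C API exceeds the 2 GiB proto limit.
absl::StatusOr<std::vector<Layout>> TryDiscoverLayouts(bool simulate_failure) {
  if (simulate_failure) {
    return absl::FailedPreconditionError("serialized HLO too large");
  }
  return std::vector<Layout>{{1}, {2}};
}

// Best-effort wrapper: a discovery error yields std::nullopt instead of
// propagating, so executable creation still succeeds and callers later
// fall back to the concrete layouts carried by the output buffers.
std::optional<std::vector<Layout>> DiscoverLayoutsBestEffort(bool fail) {
  absl::StatusOr<std::vector<Layout>> layouts = TryDiscoverLayouts(fail);
  if (!layouts.ok()) {
    return std::nullopt;
  }
  return *std::move(layouts);
}

int main() {
  std::cout << DiscoverLayoutsBestEffort(/*fail=*/true).has_value() << "\n"    // prints 0
            << DiscoverLayoutsBestEffort(/*fail=*/false).has_value() << "\n";  // prints 1
}

Returning `std::nullopt` rather than an `absl::Status` keeps the failure local: callers can distinguish "layouts unknown, use the buffers' layouts" from a hard error without plumbing statuses through every constructor, which is exactly how the diff below threads the new `std::optional` through `CreateInternal` and the `PjRtLoadedExecutable` constructor.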
@@ -296,17 +296,34 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
   TF_ASSIGN_OR_RETURN(
       auto result_memory_kinds,
       GetFirstModuleOutputMemoryKinds(pjrt_loaded_executable.get()));
-  TF_ASSIGN_OR_RETURN(auto hlo_modules,
-                      pjrt_loaded_executable->GetHloModules());
-  if (hlo_modules.empty()) {
-    return FailedPrecondition("Requires at least one HloModule.");
+  // Obtaining output layout modes and output layouts directly from
+  // `PjRtLoadedExecutable` may fail because the current PjRt implementations
+  // often fetch and serialize the optimized HLO. For now, we gracefully
+  // handle it by omitting output layouts at creation time and using output
+  // `PjRtBuffer`'s concrete layouts.
+  // TODO(hyeontaek): Add a way to obtain output layout modes and
+  // `PjRtLoadedExecutable::GetOutputLayouts()` without causing the optimized
+  // HLO to be serialized and fetched.
+  std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+      output_layouts;
+  absl::StatusOr<std::vector<std::shared_ptr<HloModule>>> hlo_modules =
+      pjrt_loaded_executable->GetHloModules();
+  if (hlo_modules.ok()) {
+    if (hlo_modules->empty()) {
+      return FailedPrecondition("Requires at least one HloModule.");
+    }
+    absl::StatusOr<std::vector<xla::LayoutMode>> output_layout_modes =
+        GetLayoutModes(*hlo_modules->front(), "out_layout_modes",
+                       result_element_types.size());
+    if (output_layout_modes.ok()) {
+      absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+          first_module_output_layouts = GetFirstModuleOutputLayouts(
+              pjrt_loaded_executable.get(), *output_layout_modes);
+      if (first_module_output_layouts.ok()) {
+        output_layouts = *std::move(first_module_output_layouts);
+      }
+    }
   }
-  TF_ASSIGN_OR_RETURN(std::vector<xla::LayoutMode> output_layout_modes,
-                      GetLayoutModes(*hlo_modules.front(), "out_layout_modes",
-                                     result_element_types.size()));
-  TF_ASSIGN_OR_RETURN(auto output_layouts,
-                      GetFirstModuleOutputLayouts(pjrt_loaded_executable.get(),
-                                                  output_layout_modes));
   return CreateInternal(client, std::move(pjrt_loaded_executable),
                         result_element_types, result_dimensions,
                         /*result_hlo_sharding=*/std::nullopt,

@@ -352,8 +369,8 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
   // will use the MLIR as scratch space, or possibly even deallocate it.
   TF_ASSIGN_OR_RETURN(const std::vector<xla::Shape> result_shapes,
                       ResultShapesOfModule(module));
-  TF_ASSIGN_OR_RETURN(const std::vector<xla::LayoutMode> output_layout_modes,
-                      GetOutputLayoutModes(module));
+  absl::StatusOr<std::vector<xla::LayoutMode>> output_layout_modes =
+      GetOutputLayoutModes(module);

   TF_ASSIGN_OR_RETURN(auto pjrt_loaded_executable,
                       client->pjrt_client()->CompileAndLoad(

@@ -372,9 +389,24 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
   TF_ASSIGN_OR_RETURN(
       auto result_memory_kinds,
       GetFirstModuleOutputMemoryKinds(pjrt_loaded_executable.get()));
-  TF_ASSIGN_OR_RETURN(auto output_layouts,
-                      GetFirstModuleOutputLayouts(
-                          pjrt_loaded_executable.get(), output_layout_modes));
+  // Obtaining output layout modes and output layouts directly from
+  // `PjRtLoadedExecutable` may fail because the current PjRt
+  // implementations often fetch and serialize the optimized HLO. For now, we
+  // gracefully handle it by omitting output layouts at creation time and
+  // using output `PjRtBuffer`'s concrete layouts.
+  // TODO(hyeontaek): Add a way to obtain output layout modes and
+  // `PjRtLoadedExecutable::GetOutputLayouts()` without causing the optimized
+  // HLO to be serialized and fetched.
+  std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+      output_layouts;
+  if (output_layout_modes.ok()) {
+    absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+        first_module_output_layouts = GetFirstModuleOutputLayouts(
+            pjrt_loaded_executable.get(), *output_layout_modes);
+    if (first_module_output_layouts.ok()) {
+      output_layouts = *std::move(first_module_output_layouts);
+    }
+  }
   return CreateInternal(client, std::move(pjrt_loaded_executable),
                         result_element_types, result_dimensions,
                         /*result_hlo_sharding=*/std::nullopt,

@@ -405,9 +437,24 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
   TF_ASSIGN_OR_RETURN(
       auto result_memory_kinds,
       GetFirstModuleOutputMemoryKinds(pjrt_loaded_executable.get()));
-  TF_ASSIGN_OR_RETURN(auto output_layouts,
-                      GetFirstModuleOutputLayouts(
-                          pjrt_loaded_executable.get(), output_layout_modes));
+  // Obtaining output layout modes and output layouts directly from
+  // `PjRtLoadedExecutable` may fail because the current PjRt
+  // implementations often fetch and serialize the optimized HLO. For now, we
+  // gracefully handle it by omitting output layouts at creation time and
+  // using output `PjRtBuffer`'s concrete layouts.
+  // TODO(hyeontaek): Add a way to obtain output layout modes and
+  // `PjRtLoadedExecutable::GetOutputLayouts()` without causing the optimized
+  // HLO to be serialized and fetched.
+  std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+      output_layouts;
+  if (output_layout_modes.ok()) {
+    absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+        first_module_output_layouts = GetFirstModuleOutputLayouts(
+            pjrt_loaded_executable.get(), *output_layout_modes);
+    if (first_module_output_layouts.ok()) {
+      output_layouts = *std::move(first_module_output_layouts);
+    }
+  }
   return CreateInternal(
       client, std::move(pjrt_loaded_executable),
       shape_partial_info.element_types, shape_partial_info.dimensions,

@@ -423,7 +470,8 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::CreateInternal(
     absl::Span<const xla::DimensionVector> result_dimensions,
     const std::optional<xla::HloSharding>& result_hlo_sharding,
    const std::optional<std::vector<absl::string_view>>& result_memory_kinds,
-    const std::vector<std::shared_ptr<const xla::PjRtLayout>>& output_layouts,
+    const std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>&
+        output_layouts,
     std::vector<tsl::RCReference<LoadedHostCallback>> loaded_host_callbacks,
     DeviceListRef executable_devices) {
   // For jit(pmap(...)), the device assignment (passed as `executable_devices`)

@@ -596,7 +644,8 @@ PjRtLoadedExecutable::PjRtLoadedExecutable(
         host_send_recv_callbacks,
     std::vector<DType> output_dtypes, std::vector<Shape> output_shapes,
     std::vector<ShardingRef> output_shardings,
-    std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts)
+    std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+        output_layouts)
     : client_(client),
       pjrt_loaded_executable_(std::move(pjrt_loaded_executable)),
       devices_(std::move(devices)),

@@ -812,6 +861,41 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
   // memory_kind shares the same Sharding object.
   absl::flat_hash_map<MemoryKind, ShardingRef> single_device_shardings;

+  std::vector<std::shared_ptr<const xla::PjRtLayout>> layouts;
+  layouts.reserve(num_outputs);
+  if (output_layouts_.has_value()) {
+    // TODO(hyeontaek): Once we can get `output_layouts_` reliably, only keep
+    // this path.
+    layouts = *output_layouts_;
+  } else if (!pjrt_outputs.empty()) {
+    for (int i = 0; i < num_outputs; ++i) {
+      auto layout = output_dtypes_[i].kind() == xla::ifrt::DType::kToken
+                        ? std::make_shared<xla::PjRtLayout>(xla::Layout())
+                        : pjrt_outputs.front()[i]->layout();
+      layouts.push_back(std::move(layout));
+    }
+  } else {
+    auto maybe_layouts = GetOutputLayouts();
+    if (absl::IsUnimplemented(maybe_layouts.status())) {
+      for (int i = 0; i < num_outputs; ++i) {
+        std::shared_ptr<const xla::PjRtLayout> layout;
+        if (output_dtypes_[i].kind() == xla::ifrt::DType::kToken) {
+          layout = std::make_shared<xla::PjRtLayout>(xla::Layout());
+        } else {
+          TF_ASSIGN_OR_RETURN(layout,
+                              client_->GetDefaultPjRtLayout(
+                                  output_dtypes_[i], output_shapes_[i].dims(),
+                                  devices_->devices().front(),
+                                  output_shardings_[i]->memory_kind()));
+        }
+        layouts.push_back(std::move(layout));
+      }
+    } else {
+      TF_RETURN_IF_ERROR(maybe_layouts.status());
+      layouts = *std::move(maybe_layouts);
+    }
+  }
+
   for (int i = 0; i < num_outputs; ++i) {
     PjRtArray::PjRtBuffers buffers;
     buffers.reserve(num_computations);

@@ -852,7 +936,7 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
     }
     outputs.push_back(*PjRtArray::Create(
         client_, output_dtypes_[i], output_shapes_[i], *std::move(sharding),
-        std::move(buffers), output_layouts_[i]));
+        std::move(buffers), std::move(layouts[i])));
   }

   ExecuteResult result;

@@ -339,7 +339,8 @@ class PjRtLoadedExecutable final
       absl::Span<const xla::DimensionVector> result_dimensions,
       const std::optional<xla::HloSharding>& result_hlo_sharding,
       const std::optional<std::vector<absl::string_view>>& result_memory_kinds,
-      const std::vector<std::shared_ptr<const xla::PjRtLayout>>& output_layouts,
+      const std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>&
+          output_layouts,
       std::vector<tsl::RCReference<LoadedHostCallback>> loaded_host_callbacks,
       DeviceListRef executable_devices);

@@ -353,7 +354,8 @@ class PjRtLoadedExecutable final
           host_send_recv_callbacks,
       std::vector<DType> output_dtypes, std::vector<Shape> output_shapes,
       std::vector<ShardingRef> output_shardings,
-      std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts);
+      std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+          output_layouts);

   PjRtClient* client_;
   std::shared_ptr<xla::PjRtLoadedExecutable> pjrt_loaded_executable_;

@@ -372,7 +374,8 @@ class PjRtLoadedExecutable final
   std::vector<DType> output_dtypes_;
   std::vector<Shape> output_shapes_;
   std::vector<ShardingRef> output_shardings_;
-  std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts_;
+  std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
+      output_layouts_;
   const xla::ifrt::UserContextRef user_context_;
 };
