[PjRt-IFRT] ifrt::PjRtArray::pjrt_layout() uses nullptr to indicate a default layout

PjRt-IFRT now returns `nullptr` when it knows that the Array layout is the default layout. User code has already been migrated to handle this new behavior gracefully, obtaining a concrete default layout as before.
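
For illustration, here is a minimal sketch of that fallback pattern, modeled on the call sites updated in this change (e.g., the reshard test and the transfer server). The helper name `ResolveConcreteLayout` and its exact signature are made up for this example, and the snippet assumes the usual IFRT/PjRt headers and TF status macros are available:

```c++
// Sketch only: resolve an array's layout to a concrete value. A non-null
// result from pjrt_layout() is a user-chosen (custom) layout; nullptr means
// the default layout, which is obtained explicitly from the client.
absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> ResolveConcreteLayout(
    xla::ifrt::Client* client, xla::ifrt::PjRtArray* array) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<const xla::PjRtLayout> layout,
                      array->pjrt_layout());
  if (layout != nullptr) {
    return layout;
  }
  // Default layout: ask the client for its concrete form.
  TF_ASSIGN_OR_RETURN(xla::ifrt::Shape shard_shape,
                      array->sharding().GetShardShape(array->shape()));
  return client->GetDefaultPjRtLayout(
      array->dtype(), shard_shape.dims(),
      array->sharding().devices()->devices().front(),
      array->sharding().memory_kind());
}
```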

`ifrt::PjRtArray` creation now requests extra information about whether the underlying `PjRtBuffer` uses a custom layout, since IFRT tracks whether an array's layout is the default. This information cannot be inferred correctly from the `PjRtBuffer` alone because `PjRtBuffer::layout()` only returns a concrete layout. PjRt would mostly work fine today if a default layout were declared to be a custom layout, but some strict layout equality checks can fail, so more precise information needs to be supplied.
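
A brief sketch of the producer side, modeled on the `CreatePjRtArray()` call sites in this change; the wrapper function and its flag parameter are illustrative and not part of the API:

```c++
// Sketch only: wrap a PjRtBuffer into an IFRT array while telling IFRT whether
// the buffer's layout was chosen by the user. Passing has_custom_layout=false
// lets the array report nullptr (default layout) from pjrt_layout(), even
// though PjRtBuffer::layout() always returns a concrete layout.
absl::StatusOr<tsl::RCReference<xla::ifrt::PjRtCompatibleArray>> WrapBuffer(
    xla::ifrt::PjRtCompatibleClient* ifrt_client,
    std::shared_ptr<xla::PjRtBuffer> pjrt_buffer,
    bool layout_was_user_chosen) {
  return ifrt_client->CreatePjRtArray(
      std::move(pjrt_buffer),
      /*has_custom_layout=*/layout_was_user_chosen);
}
```

For instance, buffers obtained from `CreateBuffersForAsyncHostToDevice` use the backend's default layout, so the DMA path updated in this change passes `/*has_custom_layout=*/false`.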

A few IFRT ArrayImplTest cases against the PjRt CPU and GPU clients have been disabled because the output array does not correctly track the non-default-ness of the layout when `MakeArraysFromHostBufferShards()` is implemented using `ClientMakeArraysFromHostBufferShards()`.

PiperOrigin-RevId: 819995407
Hyeontaek Lim 2025-10-15 18:29:47 -07:00 committed by TensorFlower Gardener
parent 0c8f3eab9a
commit 55371dfcb4
14 changed files with 291 additions and 133 deletions

View File

@ -302,10 +302,6 @@ TEST(ArrayImplTest, MakeArrayFromHostBufferDefaultLayout) {
for (Memory* const memory : device->Memories()) {
SCOPED_TRACE(absl::StrCat(memory->Kind()));
TF_ASSERT_OK_AND_ASSIGN(auto default_layout,
client->GetDefaultPjRtLayout(
dtype, shape.dims(), device, memory->Kind()));
TF_ASSERT_OK_AND_ASSIGN(
auto array,
client->MakeArrayFromHostBuffer(
@ -316,9 +312,15 @@ TEST(ArrayImplTest, MakeArrayFromHostBufferDefaultLayout) {
TF_ASSERT_OK(array->GetReadyFuture().Await());
TF_ASSERT_OK_AND_ASSIGN(auto layout, array->pjrt_layout());
ASSERT_NE(layout, nullptr);
// `layout` should be either nullptr or a concrete default layout.
if (layout != nullptr) {
TF_ASSERT_OK_AND_ASSIGN(auto default_layout,
client->GetDefaultPjRtLayout(
dtype, shape.dims(), device, memory->Kind()));
EXPECT_EQ(*layout, *default_layout);
}
}
}
TEST(ArrayImplTest, MakeArrayFromHostBufferAndCopyToHostBuffer) {
@ -1451,12 +1453,14 @@ TEST(ArrayImplTest, CopyPreservesDefaultLayouts) {
TF_ASSERT_OK(array->GetReadyFuture().Await());
TF_ASSERT_OK_AND_ASSIGN(auto src_layout, array->pjrt_layout());
ASSERT_NE(src_layout, nullptr);
// `layout` should be either nullptr or a concrete default layout.
if (src_layout != nullptr) {
TF_ASSERT_OK_AND_ASSIGN(
auto src_default_layout,
client->GetDefaultPjRtLayout(dtype, shape.dims(), device,
src_memory->Kind()));
EXPECT_EQ(*src_layout, *src_default_layout);
}
TF_ASSERT_OK_AND_ASSIGN(
auto new_arrays, client->CopyArrays(absl::MakeSpan(&array, 1),
@ -1464,7 +1468,8 @@ TEST(ArrayImplTest, CopyPreservesDefaultLayouts) {
ArrayCopySemantics::kAlwaysCopy));
ASSERT_THAT(new_arrays, SizeIs(1));
TF_ASSERT_OK_AND_ASSIGN(auto dst_layout, new_arrays[0]->pjrt_layout());
ASSERT_NE(dst_layout, nullptr);
// `layout` should be either nullptr or a concrete default layout.
if (dst_layout != nullptr) {
TF_ASSERT_OK_AND_ASSIGN(
auto dst_default_layout,
client->GetDefaultPjRtLayout(dtype, shape.dims(), device,
@ -1472,6 +1477,7 @@ TEST(ArrayImplTest, CopyPreservesDefaultLayouts) {
EXPECT_EQ(*dst_layout, *dst_default_layout);
}
}
}
}
TEST(ArrayImplTest, MakeAndCopyZeroSizedBuffers) {

View File

@ -24,6 +24,14 @@ int main(int argc, char** argv) {
// destination literal.
"ArrayImplTest.MakeArrayFromHostBufferAndCopyToHostBufferWithByteStrides",
// Arrays created using `MakeArraysFromHostBufferShards()` do not indicate
// custom layouts correctly even if the given layout is a concrete default
// layout. PjRt-IFRT uses `ClientMakeArraysFromHostBufferShards()`
// internally, which lowers the `MakeArraysFromHostBufferShards()` call into
// legacy API calls that do not yet support custom layouts, so the output
// arrays can only have default layouts.
"ArrayImplTest.MakeArraysFromHostBufferShardsWithLayout",
// `ShardingParamSharding` does not support serialization yet.
// TODO(b/282757875): Enable the test once IFRT implements
// `ShardingParamShardingSerDes`.

View File

@ -329,6 +329,7 @@ cc_library(
"//xla/hlo/translate/mhlo_to_hlo:type_to_shape",
"//xla/pjrt:host_callback",
"//xla/pjrt:host_memory_spaces",
"//xla/pjrt:layout_mode",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_common",
"//xla/pjrt:pjrt_compiler",

View File

@ -174,16 +174,20 @@ absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
}
absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
PjRtCompatibleClient* client, std::shared_ptr<PjRtBuffer> pjrt_buffer) {
PjRtCompatibleClient* client, std::shared_ptr<PjRtBuffer> pjrt_buffer,
bool has_custom_layout) {
TF_ASSIGN_OR_RETURN(auto dtype, ToDType(pjrt_buffer->element_type()));
Shape shape(pjrt_buffer->dimensions());
TF_ASSIGN_OR_RETURN(auto device,
client->LookupPjRtDevice(pjrt_buffer->device()));
auto sharding = SingleDeviceSharding::Create(
device, MakeMemoryKindFromPjRtBuffer(pjrt_buffer.get()));
auto layout = (dtype.kind() == DType::kToken)
std::shared_ptr<const xla::PjRtLayout> layout;
if (has_custom_layout) {
layout = (dtype.kind() == DType::kToken)
? std::make_shared<xla::PjRtLayout>(xla::Layout())
: pjrt_buffer->layout();
}
return tsl::MakeRef<PjRtArray>(
client, dtype, std::move(shape), std::move(sharding),
PjRtBuffers({std::move(pjrt_buffer)}), std::move(layout));
@ -195,7 +199,8 @@ absl::StatusOr<ArrayRef> PjRtArray::FullyReplicatedShard(
return FailedPrecondition(
"FullyReplicatedShard: Array has no addressable shards.");
}
return PjRtArray::Create(client(), GetPjRtBuffer(semantics, 0));
return PjRtArray::Create(client(), GetPjRtBuffer(semantics, 0),
/*has_custom_layout=*/(layout_ != nullptr));
}
std::shared_ptr<PjRtBuffer> PjRtArray::GetPjRtBuffer(
@ -215,7 +220,8 @@ std::shared_ptr<PjRtBuffer> PjRtArray::GetPjRtBuffer(
}
absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
PjRtCompatibleClient* client, Shape shape, PjRtBuffers pjrt_buffers) {
PjRtCompatibleClient* client, Shape shape, PjRtBuffers pjrt_buffers,
bool has_custom_layout) {
if (pjrt_buffers.empty()) {
return InvalidArgument("PjRtBuffers must be non-empty.");
}
@ -239,14 +245,19 @@ absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
BasicDeviceList::Create(std::move(devices)), memory_kind,
/*shape=*/shape,
/*shard_shapes=*/shapes);
auto layout = pjrt_buffers.front()->layout();
std::shared_ptr<const xla::PjRtLayout> layout;
if (has_custom_layout) {
layout = (dtype.kind() == DType::kToken)
? std::make_shared<xla::PjRtLayout>(xla::Layout())
: pjrt_buffers.front()->layout();
}
return PjRtArray::Create(client, dtype, std::move(shape), std::move(sharding),
std::move(pjrt_buffers), std::move(layout));
}
absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
PjRtCompatibleClient* client, DynamicShape dynamic_shape,
PjRtBuffers pjrt_buffers) {
PjRtBuffers pjrt_buffers, bool has_custom_layout) {
if (pjrt_buffers.empty()) {
return InvalidArgument("PjRtBuffers must be non-empty.");
}
@ -276,7 +287,12 @@ absl::StatusOr<tsl::RCReference<PjRtArray>> PjRtArray::Create(
BasicDeviceList::Create(std::move(devices)), memory_kind,
/*dynamic_shape=*/dynamic_shape,
/*shard_dynamic_shapes=*/dynamic_shapes);
auto layout = pjrt_buffers.front()->layout();
std::shared_ptr<const xla::PjRtLayout> layout;
if (has_custom_layout) {
layout = (dtype.kind() == DType::kToken)
? std::make_shared<xla::PjRtLayout>(xla::Layout())
: pjrt_buffers.front()->layout();
}
return PjRtArray::Create(client, dtype, std::move(dynamic_shape),
std::move(sharding), std::move(pjrt_buffers),
std::move(layout));
@ -534,13 +550,22 @@ absl::StatusOr<ArrayRef> PjRtArray::Copy(
if (new_client == nullptr) {
new_client = client_;
}
std::shared_ptr<const xla::PjRtLayout> layout;
static MemoryKind kUnpinnedHostMemoryKind(UnpinnedHostMemorySpace::kKind);
// Unpinned host supports default layouts only; a custom layout would be
// ignored.
// TODO(hyeontaek): This behavior should be informed by the underlying PjRt
// client instead of following a convention.
if (layout_ != nullptr &&
canonicalized_sharding_memory_kind != kUnpinnedHostMemoryKind) {
layout = layout_;
}
return std::visit(
[this, new_client, &new_sharding, &buffers](const auto& shape) {
std::shared_ptr<const xla::PjRtLayout> buffer_layout =
buffers[0]->layout();
[this, new_client, &new_sharding, &buffers,
layout = std::move(layout)](const auto& shape) {
return PjRtArray::Create(new_client, dtype_, shape,
std::move(new_sharding), std::move(buffers),
std::move(buffer_layout));
layout);
},
shape_);
}
@ -609,10 +634,10 @@ absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> PjRtArray::pjrt_layout()
for (int i = 1; i < pjrt_buffers_.size(); ++i) {
std::shared_ptr<const xla::PjRtLayout> layout_i =
pjrt_buffers_[i]->layout();
DCHECK(*layout_ == *layout_i)
DCHECK(*pjrt_buffers_[0]->layout() == *layout_i)
<< "PjRtArray has mismatched layouts across shards! "
<< "shard 0: " << layout_->ToString() << ", shard " << i << ": "
<< layout_i->ToString();
<< "shard 0: " << pjrt_buffers_[0]->layout()->ToString() << ", shard "
<< i << ": " << layout_i->ToString();
}
#endif
return layout_;

View File

@ -68,34 +68,44 @@ class PjRtArray final
using PjRtBuffers =
absl::InlinedVector<std::shared_ptr<PjRtBuffer>, kPjRtBufferInlineSize>;
// General array construction (with static shape). pjrt_buffers may be empty.
// General array construction (with static shape). `pjrt_buffers` may be
// empty. `layout == nullptr` indicates a default layout.
static absl::StatusOr<tsl::RCReference<PjRtArray>> Create(
PjRtCompatibleClient* client, DType dtype, Shape shape,
ShardingRef sharding, PjRtBuffers pjrt_buffers,
std::shared_ptr<const xla::PjRtLayout> layout);
// General array construction (with dynamic shape). pjrt_buffers may be empty.
// General array construction (with dynamic shape). `pjrt_buffers` may be
// empty. `layout == nullptr` indicates a default layout.
static absl::StatusOr<tsl::RCReference<PjRtArray>> Create(
PjRtCompatibleClient* client, DType dtype, DynamicShape dynamic_shape,
ShardingRef sharding, PjRtBuffers pjrt_buffers,
std::shared_ptr<const xla::PjRtLayout> layout);
// Shorthand for a single-shard array construction.
// See `PjRtCompatibleClient::CreatePjRtArray()` for the meaning of
// `has_custom_layout`.
static absl::StatusOr<tsl::RCReference<PjRtArray>> Create(
PjRtCompatibleClient* client, std::shared_ptr<PjRtBuffer> pjrt_buffer);
PjRtCompatibleClient* client, std::shared_ptr<PjRtBuffer> pjrt_buffer,
bool has_custom_layout);
// Shorthand for a multi-shard array construction using ConcreteSharding.
// pjrt_buffers must be non-empty.
// `pjrt_buffers` must be non-empty.
// See `PjRtCompatibleClient::CreatePjRtArray()` for the meaning of
// `has_custom_layout`.
// TODO(hyeontaek): Remove this once IFRT Sharding and JAX Sharding is unified
// so that ConcreteSharding can be replaced with a real Sharding.
static absl::StatusOr<tsl::RCReference<PjRtArray>> Create(
PjRtCompatibleClient* client, Shape shape, PjRtBuffers pjrt_buffers);
PjRtCompatibleClient* client, Shape shape, PjRtBuffers pjrt_buffers,
bool has_custom_layout);
// Shorthand for a multi-shard array construction using ConcreteSharding with
// DynamicShape. pjrt_buffers must be non-empty.
// DynamicShape. `pjrt_buffers` must be non-empty.
// See `PjRtCompatibleClient::CreatePjRtArray()` for the meaning of
// `has_custom_layout`.
static absl::StatusOr<tsl::RCReference<PjRtArray>> Create(
PjRtCompatibleClient* client, DynamicShape dynamic_shape,
PjRtBuffers pjrt_buffers);
PjRtBuffers pjrt_buffers, bool has_custom_layout);
// PjRtCompatibleArray implementation.

View File

@ -18,11 +18,18 @@ limitations under the License.
#include "xla/python/ifrt/test_util.h"
int main(int argc, char** argv) {
static constexpr absl::string_view kFilter =
// CpuBuffer::ToLiteral() currently does not respect the layout of the
// destination literal.
static constexpr absl::string_view kFilter =
"-ArrayImplTest."
"MakeArrayFromHostBufferAndCopyToHostBufferWithByteStrides";
"MakeArrayFromHostBufferAndCopyToHostBufferWithByteStrides:"
// Arrays created using `MakeArraysFromHostBufferShards()` do not indicate
// custom layouts correctly even if the given layout is a concrete default
// layout. PjRt-IFRT uses `ClientMakeArraysFromHostBufferShards()`
// internally, which lowers the `MakeArraysFromHostBufferShards()` call into
// legacy API calls that do not yet support custom layouts, so the output
// arrays can only have default layouts.
"ArrayImplTest.MakeArraysFromHostBufferShardsWithLayout";
xla::ifrt::test_util::SetTestFilterIfNotUserSpecified(kFilter);
testing::InitGoogleTest(&argc, argv);

View File

@ -725,6 +725,17 @@ const char kKeyPrefix[] = "ifrt_cross_host_transfer_";
char PjRtCompatibleClient::ID = 0;
char PjRtClient::ID = 0;
absl::StatusOr<tsl::RCReference<PjRtCompatibleArray>>
PjRtCompatibleClient::CreatePjRtArray(std::shared_ptr<PjRtBuffer> pjrt_buffer) {
return CreatePjRtArray(std::move(pjrt_buffer), /*has_custom_layout=*/true);
}
absl::StatusOr<tsl::RCReference<PjRtCompatibleArray>>
PjRtCompatibleClient::CreatePjRtArray(Shape shape, PjRtBuffers pjrt_buffers) {
return CreatePjRtArray(std::move(shape), std::move(pjrt_buffers),
/*has_custom_layout=*/true);
}
absl::StatusOr<std::unique_ptr<PjRtClient>> PjRtClient::Create(
PjRtClient::CreateOptions options) {
auto client =
@ -955,16 +966,21 @@ absl::StatusOr<DeviceListRef> PjRtClient::MakeDeviceList(
const AttributeMap& PjRtClient::Attributes() const { return attributes_; }
absl::StatusOr<tsl::RCReference<PjRtCompatibleArray>>
PjRtClient::CreatePjRtArray(std::shared_ptr<PjRtBuffer> pjrt_buffer) {
TF_ASSIGN_OR_RETURN(auto array,
PjRtArray::Create(this, std::move(pjrt_buffer)));
PjRtClient::CreatePjRtArray(std::shared_ptr<PjRtBuffer> pjrt_buffer,
bool has_custom_layout) {
TF_ASSIGN_OR_RETURN(
auto array,
PjRtArray::Create(this, std::move(pjrt_buffer), has_custom_layout));
return tsl::RCReference<PjRtCompatibleArray>(std::move(array));
}
absl::StatusOr<tsl::RCReference<PjRtCompatibleArray>>
PjRtClient::CreatePjRtArray(Shape shape, PjRtBuffers pjrt_buffers) {
PjRtClient::CreatePjRtArray(Shape shape, PjRtBuffers pjrt_buffers,
bool has_custom_layout) {
std::shared_ptr<const xla::PjRtLayout> layout;
TF_ASSIGN_OR_RETURN(auto array, PjRtArray::Create(this, std::move(shape),
std::move(pjrt_buffers)));
std::move(pjrt_buffers),
has_custom_layout));
return tsl::RCReference<PjRtCompatibleArray>(std::move(array));
}
@ -1055,7 +1071,8 @@ absl::StatusOr<ArrayRef> PjRtClient::MakeArrayFromHostBuffer(
}
buffers.push_back(std::move(buffer));
}
auto layout = buffers.front()->layout();
// `MakeArrayFromHostBuffer` only creates buffers with a default layout.
std::shared_ptr<const xla::PjRtLayout> layout = nullptr;
return PjRtArray::Create(this, dtype, std::move(shape), std::move(sharding),
std::move(buffers), std::move(layout));
}
@ -1121,12 +1138,11 @@ absl::StatusOr<std::vector<ArrayRef>> PjRtClient::MakeErrorArrays(
error, xla_shape,
tensorflow::down_cast<PjRtMemory*>(memory)->pjrt_memory()));
}
auto layout = buffers.front()->layout();
TF_ASSIGN_OR_RETURN(
arrays.emplace_back(),
PjRtArray::Create(this, array_spec.dtype, std::move(shard_shape),
array_spec.sharding, std::move(buffers),
std::move(layout)));
array_spec.layout));
}
return arrays;
}
@ -1211,16 +1227,8 @@ absl::StatusOr<ArrayRef> PjRtClient::AssembleArrayFromSingleDeviceArrays(
}
// TODO(emilyaf): Remove the following logic once layout is plumbed through.
std::shared_ptr<const xla::PjRtLayout> layout;
if (dtype.kind() == DType::kToken) {
layout = std::make_shared<xla::PjRtLayout>(xla::Layout());
} else if (buffers.empty()) {
TF_ASSIGN_OR_RETURN(auto shard_shape, sharding->GetShardShape(shape));
TF_ASSIGN_OR_RETURN(
layout, GetDefaultPjRtLayout(dtype, shard_shape.dims(),
sharding->devices()->devices().front(),
sharding->memory_kind()));
} else {
layout = buffers.front()->layout();
if (!arrays.empty()) {
TF_ASSIGN_OR_RETURN(layout, arrays.front()->pjrt_layout());
}
return PjRtArray::Create(this, dtype, std::move(shape), std::move(sharding),
std::move(buffers), std::move(layout));
@ -1387,16 +1395,6 @@ PjRtClient::CopyArraysForCrossHost(absl::Span<ArrayRef> arrays,
arrays[i]->shared_ptr_sharding()->WithDeviceAssignment(
dst_devices, memory_kind));
TF_ASSIGN_OR_RETURN(auto new_layout, arrays[i]->pjrt_layout());
if (new_layout == nullptr) {
TF_ASSIGN_OR_RETURN(
xla::ifrt::Shape shard_shape,
arrays[i]->sharding().GetShardShape(arrays[i]->shape()));
TF_ASSIGN_OR_RETURN(
new_layout, GetDefaultPjRtLayout(
arrays[i]->dtype(), shard_shape.dims(),
arrays[i]->sharding().devices()->devices().front(),
arrays[i]->sharding().memory_kind()));
}
TF_ASSIGN_OR_RETURN(
new_arrays.emplace_back(),
PjRtArray::Create(this, arrays[i]->dtype(), arrays[i]->shape(),

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
#include "absl/base/thread_annotations.h"
@ -91,10 +92,32 @@ class PjRtCompatibleClient
// operations.
virtual xla::PjRtClient* pjrt_client() = 0;
virtual std::shared_ptr<xla::PjRtClient> shared_ptr_pjrt_client() = 0;
// Creates an IFRT `PjRtCompatibleArray` from `PjRtBuffer`(s).
//
// Most array properties will be inferred from the input `PjRtBuffer`(s),
// except for whether the layout is the default, which is information that is
// not available at the PjRt level.
//
// `has_custom_layout` indicates that the layout of the input `PjRtBuffer`(s)
// is intended to be a user-chosen custom layout, and
// `PjRtCompatibleArray::pjrt_layout()` should return a non-null value.
// Treating a default layout as a custom layout is typically allowed in PjRt
// if their concrete layouts match, but it may not pass a strict check,
// designed for portability, that unconditionally treats a default layout as
// different from any non-default layout. Thus, it is useful for the caller to
// provide information that is as accurate as possible.
virtual absl::StatusOr<tsl::RCReference<PjRtCompatibleArray>> CreatePjRtArray(
std::shared_ptr<PjRtBuffer> pjrt_buffer) = 0;
std::shared_ptr<PjRtBuffer> pjrt_buffer, bool has_custom_layout) = 0;
virtual absl::StatusOr<tsl::RCReference<PjRtCompatibleArray>> CreatePjRtArray(
Shape shape, PjRtBuffers pjrt_buffers) = 0;
Shape shape, PjRtBuffers pjrt_buffers, bool has_custom_layout) = 0;
// Temporary overloads for API transition.
absl::StatusOr<tsl::RCReference<PjRtCompatibleArray>> CreatePjRtArray(
std::shared_ptr<PjRtBuffer> pjrt_buffer);
absl::StatusOr<tsl::RCReference<PjRtCompatibleArray>> CreatePjRtArray(
Shape shape, PjRtBuffers pjrt_buffers);
virtual absl::StatusOr<PjRtCompatibleDevice*> LookupPjRtDevice(
xla::PjRtDevice* pjrt_device) const = 0;
virtual absl::StatusOr<PjRtCompatibleMemory*> LookupPjRtMemory(
@ -178,9 +201,9 @@ class PjRtClient final
return pjrt_client_;
}
absl::StatusOr<tsl::RCReference<PjRtCompatibleArray>> CreatePjRtArray(
std::shared_ptr<PjRtBuffer> pjrt_buffer) override;
std::shared_ptr<PjRtBuffer> pjrt_buffer, bool has_custom_layout) override;
absl::StatusOr<tsl::RCReference<PjRtCompatibleArray>> CreatePjRtArray(
Shape shape, PjRtBuffers pjrt_buffers) override;
Shape shape, PjRtBuffers pjrt_buffers, bool has_custom_layout) override;
// Client implementation.

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "xla/python/pjrt_ifrt/pjrt_executable.h"
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
@ -27,6 +28,7 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
@ -38,11 +40,13 @@ limitations under the License.
#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h"
#include "xla/layout.h"
#include "xla/pjrt/host_callback.h"
#include "xla/pjrt/layout_mode.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/pjrt/utils.h"
#include "xla/primitive_util.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/attribute_map.h"
@ -147,6 +151,7 @@ absl::StatusOr<std::optional<xla::HloSharding>> GetFirstModuleOutputSharding(
}
// Returns the flattened output memory_kinds of the first module in a
// `PjRtLoadedExecutable`.
// `UnimplementedError` will be converted into `std::nullopt`.
absl::StatusOr<std::optional<std::vector<absl::string_view>>>
GetFirstModuleOutputMemoryKinds(
@ -165,6 +170,39 @@ GetFirstModuleOutputMemoryKinds(
return std::move(output_memory_kinds)->front();
}
// Returns the flattened output layouts of the first module in a
// `PjRtLoadedExecutable`.
// `UnimplementedError` will be converted into a vector of `nullptr`.
absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
GetFirstModuleOutputLayouts(
xla::PjRtLoadedExecutable* pjrt_loaded_executable,
absl::Span<const xla::LayoutMode> output_layout_modes) {
absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
executable_output_layouts = pjrt_loaded_executable->GetOutputLayouts();
// An unimplemented error is converted into all-default layouts.
if (absl::IsUnimplemented(executable_output_layouts.status())) {
return std::vector<std::shared_ptr<const xla::PjRtLayout>>(
/*size=*/output_layout_modes.size(), /*value=*/nullptr);
}
TF_RETURN_IF_ERROR(executable_output_layouts.status());
std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts;
if (executable_output_layouts->size() != output_layout_modes.size()) {
return FailedPrecondition(
"Output memory kinds and output layout modes have different sizes: %d "
"vs. %d",
executable_output_layouts->size(), output_layout_modes.size());
}
output_layouts.reserve(executable_output_layouts->size());
for (int i = 0; i < executable_output_layouts->size(); ++i) {
if (output_layout_modes[i].mode == xla::LayoutMode::Mode::kDefault) {
output_layouts.push_back(nullptr);
} else {
output_layouts.push_back(std::move((*executable_output_layouts)[i]));
}
}
return output_layouts;
}
struct ShapePartialInfo {
std::vector<xla::PrimitiveType> element_types;
std::vector<xla::DimensionVector> dimensions;
@ -188,6 +226,36 @@ absl::StatusOr<ShapePartialInfo> CreateShapePartialInfo(
return partial_info;
}
// Special `xla::GetLayoutModes()` implementation for obtaining layout modes
// from `hlo_module` without serializing it into proto.
static const char* kDelimiter = ";";
static absl::StatusOr<std::vector<LayoutMode>> GetLayoutModesFromFrontendAttr(
absl::string_view attr) {
// SkipEmpty() needed to avoid returning the empty string when attr is empty.
std::vector<std::string> str_modes =
absl::StrSplit(attr, kDelimiter, absl::SkipEmpty());
std::vector<LayoutMode> result;
for (const std::string& str_mode : str_modes) {
TF_ASSIGN_OR_RETURN(LayoutMode mode, LayoutMode::FromString(str_mode));
result.emplace_back(std::move(mode));
}
return result;
}
static absl::StatusOr<std::vector<LayoutMode>> GetLayoutModes(
const HloModule& hlo_module, absl::string_view frontend_attr_name,
size_t num_values) {
const auto& frontend_attrs = hlo_module.frontend_attributes().map();
auto iter = frontend_attrs.find(frontend_attr_name);
if (iter == frontend_attrs.end()) {
// Return all default layouts if frontend attr isn't present.
return std::vector<LayoutMode>(num_values);
}
return GetLayoutModesFromFrontendAttr(iter->second);
}
} // namespace
char PjRtCompatibleExecutable::ID = 0;
@ -228,11 +296,22 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
TF_ASSIGN_OR_RETURN(
auto result_memory_kinds,
GetFirstModuleOutputMemoryKinds(pjrt_loaded_executable.get()));
TF_ASSIGN_OR_RETURN(auto hlo_modules,
pjrt_loaded_executable->GetHloModules());
if (hlo_modules.empty()) {
return FailedPrecondition("Requires at least one HloModule.");
}
TF_ASSIGN_OR_RETURN(std::vector<xla::LayoutMode> output_layout_modes,
GetLayoutModes(*hlo_modules.front(), "out_layout_modes",
result_element_types.size()));
TF_ASSIGN_OR_RETURN(auto output_layouts,
GetFirstModuleOutputLayouts(pjrt_loaded_executable.get(),
output_layout_modes));
return CreateInternal(client, std::move(pjrt_loaded_executable),
result_element_types, result_dimensions,
/*result_hlo_sharding=*/std::nullopt,
result_memory_kinds, loaded_host_callbacks,
std::move(executable_devices));
result_memory_kinds, output_layouts,
loaded_host_callbacks, std::move(executable_devices));
}
static absl::StatusOr<std::vector<xla::Shape>> ResultShapesOfModule(
@ -271,7 +350,10 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
// We have to process the MLIR before the compile call, since the latter
// will use the MLIR as scratch space, or possibly even deallocate it.
TF_ASSIGN_OR_RETURN(auto result_shapes, ResultShapesOfModule(module));
TF_ASSIGN_OR_RETURN(const std::vector<xla::Shape> result_shapes,
ResultShapesOfModule(module));
TF_ASSIGN_OR_RETURN(const std::vector<xla::LayoutMode> output_layout_modes,
GetOutputLayoutModes(module));
TF_ASSIGN_OR_RETURN(auto pjrt_loaded_executable,
client->pjrt_client()->CompileAndLoad(
@ -290,10 +372,14 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
TF_ASSIGN_OR_RETURN(
auto result_memory_kinds,
GetFirstModuleOutputMemoryKinds(pjrt_loaded_executable.get()));
TF_ASSIGN_OR_RETURN(auto output_layouts,
GetFirstModuleOutputLayouts(
pjrt_loaded_executable.get(), output_layout_modes));
return CreateInternal(client, std::move(pjrt_loaded_executable),
result_element_types, result_dimensions,
/*result_hlo_sharding=*/std::nullopt,
result_memory_kinds, std::move(loaded_host_callbacks),
result_memory_kinds, output_layouts,
std::move(loaded_host_callbacks),
std::move(executable_devices));
} else {
VLOG(3) << "Using full shape";
@ -319,11 +405,14 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
TF_ASSIGN_OR_RETURN(
auto result_memory_kinds,
GetFirstModuleOutputMemoryKinds(pjrt_loaded_executable.get()));
return CreateInternal(client, std::move(pjrt_loaded_executable),
shape_partial_info.element_types,
shape_partial_info.dimensions, result_hlo_sharding,
result_memory_kinds, std::move(loaded_host_callbacks),
std::move(executable_devices));
TF_ASSIGN_OR_RETURN(auto output_layouts,
GetFirstModuleOutputLayouts(
pjrt_loaded_executable.get(), output_layout_modes));
return CreateInternal(
client, std::move(pjrt_loaded_executable),
shape_partial_info.element_types, shape_partial_info.dimensions,
result_hlo_sharding, result_memory_kinds, output_layouts,
std::move(loaded_host_callbacks), std::move(executable_devices));
}
}
@ -334,6 +423,7 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::CreateInternal(
absl::Span<const xla::DimensionVector> result_dimensions,
const std::optional<xla::HloSharding>& result_hlo_sharding,
const std::optional<std::vector<absl::string_view>>& result_memory_kinds,
const std::vector<std::shared_ptr<const xla::PjRtLayout>>& output_layouts,
std::vector<tsl::RCReference<LoadedHostCallback>> loaded_host_callbacks,
DeviceListRef executable_devices) {
// For jit(pmap(...)), the device assignment (passed as `executable_devices`)
@ -493,7 +583,8 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::CreateInternal(
client, std::move(pjrt_loaded_executable), std::move(executable_devices),
std::move(addressable_devices), std::move(loaded_host_callbacks),
std::move(host_send_and_recv_callbacks), std::move(output_dtypes),
std::move(output_shapes), std::move(output_shardings)));
std::move(output_shapes), std::move(output_shardings),
std::move(output_layouts)));
}
PjRtLoadedExecutable::PjRtLoadedExecutable(
@ -504,7 +595,8 @@ PjRtLoadedExecutable::PjRtLoadedExecutable(
std::vector<PjRtHostSendAndRecvLoadedHostCallback*>
host_send_recv_callbacks,
std::vector<DType> output_dtypes, std::vector<Shape> output_shapes,
std::vector<ShardingRef> output_shardings)
std::vector<ShardingRef> output_shardings,
std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts)
: client_(client),
pjrt_loaded_executable_(std::move(pjrt_loaded_executable)),
devices_(std::move(devices)),
@ -516,6 +608,7 @@ PjRtLoadedExecutable::PjRtLoadedExecutable(
output_dtypes_(std::move(output_dtypes)),
output_shapes_(std::move(output_shapes)),
output_shardings_(std::move(output_shardings)),
output_layouts_(std::move(output_layouts)),
user_context_(UserContextScope::current()) {}
PjRtLoadedExecutable::~PjRtLoadedExecutable() = default;
@ -719,39 +812,6 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
// memory_kind shares the same Sharding object.
absl::flat_hash_map<MemoryKind, ShardingRef> single_device_shardings;
// TODO(emilyaf): Simplify the handling of layouts here when they're plumbed
// through from JAX.
std::vector<std::shared_ptr<const xla::PjRtLayout>> layouts;
layouts.reserve(num_outputs);
if (!pjrt_outputs.empty()) {
for (int i = 0; i < num_outputs; ++i) {
auto layout = output_dtypes_[i].kind() == xla::ifrt::DType::kToken
? std::make_shared<xla::PjRtLayout>(xla::Layout())
: pjrt_outputs.front()[i]->layout();
layouts.push_back(std::move(layout));
}
} else {
auto maybe_layouts = GetOutputLayouts();
if (absl::IsUnimplemented(maybe_layouts.status())) {
for (int i = 0; i < num_outputs; ++i) {
std::shared_ptr<const xla::PjRtLayout> layout;
if (output_dtypes_[i].kind() == xla::ifrt::DType::kToken) {
layout = std::make_shared<xla::PjRtLayout>(xla::Layout());
} else {
TF_ASSIGN_OR_RETURN(layout,
client_->GetDefaultPjRtLayout(
output_dtypes_[i], output_shapes_[i].dims(),
devices_->devices().front(),
output_shardings_[i]->memory_kind()));
}
layouts.push_back(std::move(layout));
}
} else {
TF_RETURN_IF_ERROR(maybe_layouts.status());
layouts = *std::move(maybe_layouts);
}
}
for (int i = 0; i < num_outputs; ++i) {
PjRtArray::PjRtBuffers buffers;
buffers.reserve(num_computations);
@ -792,7 +852,7 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
}
outputs.push_back(*PjRtArray::Create(
client_, output_dtypes_[i], output_shapes_[i], *std::move(sharding),
std::move(buffers), std::move(layouts[i])));
std::move(buffers), output_layouts_[i]));
}
ExecuteResult result;

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/device.h"
@ -127,6 +128,9 @@ class PjRtExecutable final
absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
GetOutputLayouts() const override {
// TODO(hyeontaek): Return `output_layouts_` instead, which can distinguish
// between default and custom layouts, once the users of
// `GetOutputLayouts()` understand `nullptr` elements.
DCHECK(this);
return pjrt_executable_->GetOutputLayouts();
}
@ -335,6 +339,7 @@ class PjRtLoadedExecutable final
absl::Span<const xla::DimensionVector> result_dimensions,
const std::optional<xla::HloSharding>& result_hlo_sharding,
const std::optional<std::vector<absl::string_view>>& result_memory_kinds,
const std::vector<std::shared_ptr<const xla::PjRtLayout>>& output_layouts,
std::vector<tsl::RCReference<LoadedHostCallback>> loaded_host_callbacks,
DeviceListRef executable_devices);
@ -347,7 +352,8 @@ class PjRtLoadedExecutable final
std::vector<PjRtHostSendAndRecvLoadedHostCallback*>
host_send_recv_callbacks,
std::vector<DType> output_dtypes, std::vector<Shape> output_shapes,
std::vector<ShardingRef> output_shardings);
std::vector<ShardingRef> output_shardings,
std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts);
PjRtClient* client_;
std::shared_ptr<xla::PjRtLoadedExecutable> pjrt_loaded_executable_;
@ -366,6 +372,7 @@ class PjRtLoadedExecutable final
std::vector<DType> output_dtypes_;
std::vector<Shape> output_shapes_;
std::vector<ShardingRef> output_shardings_;
std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts_;
const xla::ifrt::UserContextRef user_context_;
};

View File

@ -401,8 +401,18 @@ TEST_F(ReshardTest, DifferentDestinationLayout) {
// Make sure that the destination layout is actually different from the source
// layout in order to ensure test coverage.
TF_ASSERT_OK_AND_ASSIGN(const auto src_layout, src_array->pjrt_layout());
ASSERT_NE(src_layout, nullptr);
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<const xla::PjRtLayout> src_layout,
src_array->pjrt_layout());
if (src_layout == nullptr) {
TF_ASSERT_OK_AND_ASSIGN(
Shape shard_shape,
src_array->sharding().GetShardShape(src_array->shape()));
TF_ASSERT_OK_AND_ASSIGN(
src_layout, client_->GetDefaultPjRtLayout(
src_array->dtype(), shard_shape.dims(),
src_array->sharding().devices()->devices().front(),
src_array->sharding().memory_kind()));
}
ASSERT_NE(src_layout->xla_layout(), dst_array_spec.layout->xla_layout());
TF_ASSERT_OK_AND_ASSIGN(

View File

@ -249,15 +249,15 @@ absl::Status PjRtTransferServer::CrossHostPull(
xla::DimensionVector(shape.dims().begin(), shape.dims().end())};
shape_specs.push_back(shape_spec);
auto pjrt_layout = arrays[i]->pjrt_layout();
absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> pjrt_layout =
arrays[i]->pjrt_layout();
std::optional<xla::Layout> layout;
if (pjrt_layout.ok() && *pjrt_layout == nullptr) {
TF_ASSIGN_OR_RETURN(
xla::ifrt::Shape shard_shape,
arrays[i]->sharding().GetShardShape(arrays[i]->shape()));
TF_ASSIGN_OR_RETURN(
std::shared_ptr<const xla::PjRtLayout> layout,
arrays[i]->client()->GetDefaultPjRtLayout(
pjrt_layout, arrays[i]->client()->GetDefaultPjRtLayout(
arrays[i]->dtype(), shard_shape.dims(),
arrays[i]->sharding().devices()->devices().front(),
arrays[i]->sharding().memory_kind()));

View File

@ -89,8 +89,10 @@ absl::StatusOr<SingleBufferCopyPlan> SetupTransferDestList(
size_t copy_size = xla::ShapeUtil::ByteSizeOf(shape);
results.dests.push_back(MakeDmaDestination(atm, 0, copy_size));
TF_ASSIGN_OR_RETURN(auto arr,
ifrt_client->CreatePjRtArray(atm->RetrieveBuffer(0)));
// `CreateBuffersForAsyncHostToDevice` uses a default layout.
TF_ASSIGN_OR_RETURN(
auto arr, ifrt_client->CreatePjRtArray(atm->RetrieveBuffer(0),
/*has_custom_layout=*/false));
results.arrays.push_back(std::move(arr));
return results;
}

View File

@ -18,6 +18,7 @@ limitations under the License.
// An increasing version number to protect jax code against breaking changes.
// In JAX, reference this via jax._src.lib.ifrt_version.
#define JAX_IFRT_VERSION_NUMBER 33
#define JAX_IFRT_VERSION_NUMBER \
34 // Explicit `has_custom_layout` argument in PjRt-IFRT Array creation.
#endif // XLA_PYTHON_VERSION_H_