[IFRT] Replace Client::GetDefaultLayout() and Array::layout() with a new version

`Client::GetDefaultLayout()` and `Array::layout()` are replaced with new
versions that use `CustomLayoutRef`.
* `Client::GetDefaultLayout()` is functionally equivalent to `Client::GetDefaultPjRtLayout()`, but uses IFRT types.
* `Array::layout()` changes semantics slightly: it can no longer return an error, and it must always return some layout (a compact layout is typically a valid choice). *In the future,* this method will see a further semantics change in which a default layout is indicated by `nullptr` (i.e., of `LayoutRef` type) instead of by its concrete layout.

Subsequent changes will introduce the initial method implementations
that simply wrap PjRt layouts using `xla::ifrt::PjRtLayout::Create()`. This
implementation needs to be defined in individual runtimes because
`xla::ifrt::PjRtLayout` is defined in PjRt-IFRT and inaccessible from
top-level IFRT.

PiperOrigin-RevId: 813940989
This commit is contained in:
Hyeontaek Lim 2025-10-01 15:30:17 -07:00 committed by TensorFlower Gardener
parent c10292b99e
commit c1a1cee310
5 changed files with 55 additions and 13 deletions

View File

@ -22,11 +22,13 @@ limitations under the License.
#include <vector>
#include "absl/base/attributes.h"
#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/pjrt/pjrt_layout.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/layout.h"
#include "xla/python/ifrt/shape.h"
#include "xla/python/ifrt/sharding.h"
#include "xla/python/ifrt/value.h"
@ -79,10 +81,10 @@ class Array : public llvm::RTTIExtends<Array, Value> {
// return UNIMPLEMENTED instead.
virtual absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> pjrt_layout()
const = 0;
// Legacy name for `pjrt_layout()`. Will be removed, and then re-introduced as
// a new signature that returns `xla::ifrt::LayoutRef`.
absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> layout() const {
return pjrt_layout();
// Returns the layout of this array as a `CustomLayoutRef`. Unlike
// `pjrt_layout()`, the return type is not wrapped in `absl::StatusOr`, so this
// method has no way to report an error to the caller.
//
// The base implementation is a placeholder that CHECK-fails unconditionally;
// it is only usable on implementations that override it.
virtual CustomLayoutRef layout() const {
// TODO(hyeontaek): Change to a pure virtual method once all implementations
// override this method.
// NOTE(review): there is no `return` after this CHECK; the function relies on
// `CHECK(false)` never returning. Confirm all supported compilers accept this
// without a missing-return diagnostic.
CHECK(false) << "Placeholder; do not use yet";
}
// Breaks an array up into per-device arrays. This is the elimination

View File

@ -15,10 +15,28 @@ limitations under the License.
#include "xla/python/ifrt/client.h"
#include <cstdint>
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/layout.h"
#include "xla/python/ifrt/memory.h"
#include "xla/python/ifrt/shape.h"
#include "xla/python/ifrt/sharding.h"
namespace xla {
namespace ifrt {
char Client::ID = 0;
// Convenience overload for callers that only know the dimensions of a single
// shard and the device it lives on. The shard dimensions are wrapped in a
// `Shape`, the device and memory kind in a `SingleDeviceSharding`, and the
// request is forwarded to the `(dtype, shape, sharding)` overload.
absl::StatusOr<CustomLayoutRef> Client::GetDefaultLayout(
    DType dtype, absl::Span<const int64_t> shard_dims, Device* device,
    xla::ifrt::MemoryKind memory_kind) const {
  const Shape shard_shape(shard_dims);
  const ShardingRef single_device_sharding =
      SingleDeviceSharding::Create(device, memory_kind);
  return GetDefaultLayout(dtype, shard_shape, single_device_sharding);
}
} // namespace ifrt
} // namespace xla

View File

@ -24,8 +24,8 @@ limitations under the License.
#include <vector>
#include "absl/base/macros.h"
#include "absl/base/nullability.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
@ -41,6 +41,7 @@ limitations under the License.
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/layout.h"
#include "xla/python/ifrt/memory.h"
#include "xla/python/ifrt/remap_plan.h"
#include "xla/python/ifrt/shape.h"
@ -344,15 +345,20 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
Device* device,
xla::ifrt::MemoryKind memory_kind) const = 0;
// Legacy name for `GetDefaultPjRtLayout()`. Will be removed, and then
// re-introduced as a new signature that returns `xla::ifrt::CustomLayoutRef`.
// TODO(hyeontaek): Change the API to take `Shape` and `Sharding` instead of
// single-shard dimensions and device.
absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> GetDefaultLayout(
DType dtype, absl::Span<const int64_t> dims, Device* device,
xla::ifrt::MemoryKind memory_kind) const {
return GetDefaultPjRtLayout(dtype, dims, device, memory_kind);
// Returns the default layout for an array with `dtype`, `shape`, and
// `sharding`.
//
// The base implementation is a placeholder: it CHECK-fails unconditionally, so
// it is only usable on implementations that override it.
virtual absl::StatusOr<CustomLayoutRef> GetDefaultLayout(
DType dtype, const Shape& shape, const ShardingRef& sharding) const {
// TODO(hyeontaek): Change to a pure virtual method once all implementations
// override this method.
CHECK(false) << "Placeholder; do not use yet";
// Not reached when the CHECK above fires; presumably kept so that the
// function body still ends in a return statement.
return absl::UnimplementedError("Not implemented yet");
}
// Helper method for `GetDefaultLayout` for when shard shape dims are known.
// Non-virtual: it builds a `Shape` from `shard_dims` and a
// `SingleDeviceSharding` from `device` and `memory_kind`, then forwards to the
// `(dtype, shape, sharding)` overload of `GetDefaultLayout`.
// TODO(hyeontaek): Remove this sugar API once the transition is complete.
absl::StatusOr<CustomLayoutRef> GetDefaultLayout(
DType dtype, absl::Span<const int64_t> shard_dims, Device* device,
xla::ifrt::MemoryKind memory_kind) const;
// Returns a UserContext that captures the current context information such as
// the stack trace. IFRT implementations that do not support UserContext will

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/dtype.h"
#include "xla/python/ifrt/layout.h"
#include "xla/python/ifrt/memory.h"
#include "xla/python/ifrt/remap_plan.h"
#include "xla/python/ifrt/shape.h"
@ -85,6 +86,9 @@ MockArray::MockArray(xla::ifrt::ArrayRef delegated)
[this]() -> absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> {
return delegated_->pjrt_layout();
});
// By default, `layout()` forwards to the delegated array.
ON_CALL(*this, layout).WillByDefault([this]() -> CustomLayoutRef {
return delegated_->layout();
});
ON_CALL(*this, DisassembleIntoSingleDeviceArrays(_, _))
.WillByDefault(
[this](ArrayCopySemantics array_copy_semantics,
@ -231,6 +235,13 @@ MockClient::MockClient(std::unique_ptr<xla::ifrt::Client> delegated)
return delegated_->GetDefaultPjRtLayout(dtype, dims, device,
memory_kind);
});
// By default, the `(dtype, shape, sharding)` overload of `GetDefaultLayout`
// forwards to the delegated client.
ON_CALL(*this, GetDefaultLayout)
.WillByDefault(
[this](
DType dtype, const Shape& shape,
const ShardingRef& sharding) -> absl::StatusOr<CustomLayoutRef> {
return delegated_->GetDefaultLayout(dtype, shape, sharding);
});
ON_CALL(*this, Attributes).WillByDefault([this]() -> const AttributeMap& {
return delegated_->Attributes();
});

View File

@ -48,6 +48,7 @@ limitations under the License.
#include "xla/python/ifrt/executable_serdes.h"
#include "xla/python/ifrt/host_callback.h"
#include "xla/python/ifrt/index_domain.h"
#include "xla/python/ifrt/layout.h"
#include "xla/python/ifrt/memory.h"
#include "xla/python/ifrt/program.h"
#include "xla/python/ifrt/remap_plan.h"
@ -83,6 +84,7 @@ class MockArray : public llvm::RTTIExtends<MockArray, Array> {
MOCK_METHOD(ShardingRef, shared_ptr_sharding, (), (const, final));
MOCK_METHOD(absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>>,
pjrt_layout, (), (const, final));
// Mock for `Array::layout()`, which returns a `CustomLayoutRef` directly
// rather than a `StatusOr`.
MOCK_METHOD(CustomLayoutRef, layout, (), (const, final));
MOCK_METHOD(UserContextRef, user_context, (), (const, final));
MOCK_METHOD(absl::StatusOr<std::vector<ArrayRef>>,
DisassembleIntoSingleDeviceArrays,
@ -184,6 +186,9 @@ class MockClient : public llvm::RTTIExtends<MockClient, Client> {
(xla::ifrt::DType dtype, absl::Span<const int64_t> dims,
xla::ifrt::Device* device, xla::ifrt::MemoryKind memory_kind),
(const, final));
// Mock for the `(dtype, shape, sharding)` overload of
// `Client::GetDefaultLayout`.
// NOTE(review): declaring `GetDefaultLayout` here presumably hides the
// non-virtual `Client::GetDefaultLayout(dtype, shard_dims, device,
// memory_kind)` helper for calls made through a `MockClient`-typed expression
// (C++ name hiding); such callers may need to go through `Client&` — confirm.
MOCK_METHOD(absl::StatusOr<CustomLayoutRef>, GetDefaultLayout,
(DType dtype, const Shape& shape, const ShardingRef& sharding),
(const, final));
MOCK_METHOD(tsl::RCReference<xla::ifrt::UserContext>, CreateUserContext, (),
(final));
// LINT.ThenChange(mock.cc:MockClientDelegation)