mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[IFRT] Replace Client::GetDefaultLayout() and Array::layout() with a new version
`Client::GetDefaultLayout()` and `Array::layout()` are replaced with new versions that use `CustomLayoutRef`. * `Client::GetDefaultLayout()` is functionally equivalent to `Client::GetDefaultPjRtLayout()`, but uses IFRT types. * `Array::layout()` changes semantics slightly: it can no longer return an error and must always return some layout (a compact layout is typically a valid choice). *In the future,* this method will see a further semantics change in which a default layout is indicated by `nullptr` (i.e., the return type becomes `LayoutRef`) instead of by its concrete layout. Subsequent changes will introduce the initial method implementations, which simply wrap PjRt layouts using `xla::ifrt::PjRtLayout::Create()`. These implementations need to be defined in the individual runtimes because `xla::ifrt::PjRtLayout` is defined in PjRt-IFRT and is inaccessible from top-level IFRT. PiperOrigin-RevId: 813940989
This commit is contained in:
parent
c10292b99e
commit
c1a1cee310
10
third_party/xla/xla/python/ifrt/array.h
vendored
10
third_party/xla/xla/python/ifrt/array.h
vendored
|
|
@ -22,11 +22,13 @@ limitations under the License.
|
|||
#include <vector>
|
||||
|
||||
#include "absl/base/attributes.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/statusor.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/Support/ExtensibleRTTI.h"
|
||||
#include "xla/pjrt/pjrt_layout.h"
|
||||
#include "xla/python/ifrt/dtype.h"
|
||||
#include "xla/python/ifrt/layout.h"
|
||||
#include "xla/python/ifrt/shape.h"
|
||||
#include "xla/python/ifrt/sharding.h"
|
||||
#include "xla/python/ifrt/value.h"
|
||||
|
|
@ -79,10 +81,10 @@ class Array : public llvm::RTTIExtends<Array, Value> {
|
|||
// return UNIMPLEMENTED instead.
|
||||
virtual absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> pjrt_layout()
|
||||
const = 0;
|
||||
// Legacy name for `pjrt_layout()`. Will be removed, and then re-introduced as
|
||||
// a new signature that returns `xla::ifrt::LayoutRef`.
|
||||
absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> layout() const {
|
||||
return pjrt_layout();
|
||||
// Returns the layout of this array as an IFRT `CustomLayoutRef`.
//
// Unlike `pjrt_layout()`, the return type carries no status, so an
// implementation cannot report an error through it.
//
// This base-class body is a placeholder: it always terminates the process
// with a fatal log message and never produces a layout.
//
// TODO(hyeontaek): Change to a pure virtual method once all implementations
// override this method.
virtual CustomLayoutRef layout() const {
  // `LOG(FATAL)` instead of `CHECK(false)`: this function has a non-void
  // return type and no `return` statement, so the termination is spelled as
  // an unconditional fatal log rather than as a runtime condition check.
  LOG(FATAL) << "Placeholder; do not use yet";
}
|
||||
|
||||
// Breaks an array up into per-device arrays. This is the elimination
|
||||
|
|
|
|||
18
third_party/xla/xla/python/ifrt/client.cc
vendored
18
third_party/xla/xla/python/ifrt/client.cc
vendored
|
|
@ -15,10 +15,28 @@ limitations under the License.
|
|||
|
||||
#include "xla/python/ifrt/client.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/python/ifrt/device.h"
|
||||
#include "xla/python/ifrt/dtype.h"
|
||||
#include "xla/python/ifrt/layout.h"
|
||||
#include "xla/python/ifrt/memory.h"
|
||||
#include "xla/python/ifrt/shape.h"
|
||||
#include "xla/python/ifrt/sharding.h"
|
||||
|
||||
namespace xla {
|
||||
namespace ifrt {
|
||||
|
||||
char Client::ID = 0;
|
||||
|
||||
// Convenience overload of `GetDefaultLayout()` for callers that hold the
// dimensions of a single shard together with a device and a memory kind.
//
// Wraps `shard_dims` in a `Shape`, builds a `SingleDeviceSharding` from
// `device` and `memory_kind`, and forwards both to the
// `(dtype, shape, sharding)` overload, returning its result unchanged.
absl::StatusOr<CustomLayoutRef> Client::GetDefaultLayout(
    DType dtype, absl::Span<const int64_t> shard_dims, Device* device,
    xla::ifrt::MemoryKind memory_kind) const {
  const Shape shard_shape(shard_dims);
  return GetDefaultLayout(
      dtype, shard_shape, SingleDeviceSharding::Create(device, memory_kind));
}
|
||||
|
||||
} // namespace ifrt
|
||||
} // namespace xla
|
||||
|
|
|
|||
24
third_party/xla/xla/python/ifrt/client.h
vendored
24
third_party/xla/xla/python/ifrt/client.h
vendored
|
|
@ -24,8 +24,8 @@ limitations under the License.
|
|||
#include <vector>
|
||||
|
||||
#include "absl/base/macros.h"
|
||||
#include "absl/base/nullability.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
|
|
@ -41,6 +41,7 @@ limitations under the License.
|
|||
#include "xla/python/ifrt/device.h"
|
||||
#include "xla/python/ifrt/device_list.h"
|
||||
#include "xla/python/ifrt/dtype.h"
|
||||
#include "xla/python/ifrt/layout.h"
|
||||
#include "xla/python/ifrt/memory.h"
|
||||
#include "xla/python/ifrt/remap_plan.h"
|
||||
#include "xla/python/ifrt/shape.h"
|
||||
|
|
@ -344,15 +345,20 @@ class Client : public llvm::RTTIExtends<Client, llvm::RTTIRoot> {
|
|||
Device* device,
|
||||
xla::ifrt::MemoryKind memory_kind) const = 0;
|
||||
|
||||
// Legacy name for `GetDefaultPjRtLayout()`. Will be removed, and then
|
||||
// re-introduced as a new signature that returns `xla::ifrt::CustomLayoutRef`.
|
||||
// TODO(hyeontaek): Change the API to take `Shape` and `Sharding` instead of
|
||||
// single-shard dimensions and device.
|
||||
absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> GetDefaultLayout(
|
||||
DType dtype, absl::Span<const int64_t> dims, Device* device,
|
||||
xla::ifrt::MemoryKind memory_kind) const {
|
||||
return GetDefaultPjRtLayout(dtype, dims, device, memory_kind);
|
||||
// Returns the default layout for an array with `dtype`, `shape`, and
|
||||
// `sharding`.
//
// The base-class body is a placeholder: `CHECK(false)` terminates the process
// before the `UnimplementedError` return below can be reached, so callers
// never observe an error status from this default implementation.
|
||||
virtual absl::StatusOr<CustomLayoutRef> GetDefaultLayout(
|
||||
DType dtype, const Shape& shape, const ShardingRef& sharding) const {
|
||||
// TODO(hyeontaek): Change to a pure virtual method once all implementations
|
||||
// override this method.
|
||||
CHECK(false) << "Placeholder; do not use yet";
|
||||
// Unreachable after the failed CHECK above; it gives the function body a
// `return` statement matching its declared return type.
return absl::UnimplementedError("Not implemented yet");
|
||||
}
|
||||
// Helper method for `GetDefaultLayout` for when shard shape dims are known.
//
// Non-virtual. The definition in client.cc wraps `shard_dims` in a `Shape`,
// creates a `SingleDeviceSharding` from `device` and `memory_kind`, and
// forwards to the `(dtype, shape, sharding)` overload.
//
// NOTE(review): a subclass that declares its own `GetDefaultLayout` hides
// this overload by the usual C++ name-hiding rule unless it adds
// `using Client::GetDefaultLayout;` - confirm callers reach it through a
// `Client` reference or pointer.
|
||||
// TODO(hyeontaek): Remove this sugar API once the transition is complete.
|
||||
absl::StatusOr<CustomLayoutRef> GetDefaultLayout(
|
||||
DType dtype, absl::Span<const int64_t> shard_dims, Device* device,
|
||||
xla::ifrt::MemoryKind memory_kind) const;
|
||||
|
||||
// Returns a UserContext that captures the current context information such as
|
||||
// the stack trace. IFRT implementations that do not support UserContext will
|
||||
|
|
|
|||
11
third_party/xla/xla/python/ifrt/mock.cc
vendored
11
third_party/xla/xla/python/ifrt/mock.cc
vendored
|
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||
#include "xla/python/ifrt/device.h"
|
||||
#include "xla/python/ifrt/device_list.h"
|
||||
#include "xla/python/ifrt/dtype.h"
|
||||
#include "xla/python/ifrt/layout.h"
|
||||
#include "xla/python/ifrt/memory.h"
|
||||
#include "xla/python/ifrt/remap_plan.h"
|
||||
#include "xla/python/ifrt/shape.h"
|
||||
|
|
@ -85,6 +86,9 @@ MockArray::MockArray(xla::ifrt::ArrayRef delegated)
|
|||
[this]() -> absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> {
|
||||
return delegated_->pjrt_layout();
|
||||
});
|
||||
ON_CALL(*this, layout).WillByDefault([this]() -> CustomLayoutRef {
|
||||
return delegated_->layout();
|
||||
});
|
||||
ON_CALL(*this, DisassembleIntoSingleDeviceArrays(_, _))
|
||||
.WillByDefault(
|
||||
[this](ArrayCopySemantics array_copy_semantics,
|
||||
|
|
@ -231,6 +235,13 @@ MockClient::MockClient(std::unique_ptr<xla::ifrt::Client> delegated)
|
|||
return delegated_->GetDefaultPjRtLayout(dtype, dims, device,
|
||||
memory_kind);
|
||||
});
|
||||
ON_CALL(*this, GetDefaultLayout)
|
||||
.WillByDefault(
|
||||
[this](
|
||||
DType dtype, const Shape& shape,
|
||||
const ShardingRef& sharding) -> absl::StatusOr<CustomLayoutRef> {
|
||||
return delegated_->GetDefaultLayout(dtype, shape, sharding);
|
||||
});
|
||||
ON_CALL(*this, Attributes).WillByDefault([this]() -> const AttributeMap& {
|
||||
return delegated_->Attributes();
|
||||
});
|
||||
|
|
|
|||
5
third_party/xla/xla/python/ifrt/mock.h
vendored
5
third_party/xla/xla/python/ifrt/mock.h
vendored
|
|
@ -48,6 +48,7 @@ limitations under the License.
|
|||
#include "xla/python/ifrt/executable_serdes.h"
|
||||
#include "xla/python/ifrt/host_callback.h"
|
||||
#include "xla/python/ifrt/index_domain.h"
|
||||
#include "xla/python/ifrt/layout.h"
|
||||
#include "xla/python/ifrt/memory.h"
|
||||
#include "xla/python/ifrt/program.h"
|
||||
#include "xla/python/ifrt/remap_plan.h"
|
||||
|
|
@ -83,6 +84,7 @@ class MockArray : public llvm::RTTIExtends<MockArray, Array> {
|
|||
MOCK_METHOD(ShardingRef, shared_ptr_sharding, (), (const, final));
|
||||
MOCK_METHOD(absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>>,
|
||||
pjrt_layout, (), (const, final));
|
||||
// Mock of `Array::layout()`. The `MockArray(delegated)` constructor in
// mock.cc installs a default action that forwards to `delegated_->layout()`.
MOCK_METHOD(CustomLayoutRef, layout, (), (const, final));
|
||||
MOCK_METHOD(UserContextRef, user_context, (), (const, final));
|
||||
MOCK_METHOD(absl::StatusOr<std::vector<ArrayRef>>,
|
||||
DisassembleIntoSingleDeviceArrays,
|
||||
|
|
@ -184,6 +186,9 @@ class MockClient : public llvm::RTTIExtends<MockClient, Client> {
|
|||
(xla::ifrt::DType dtype, absl::Span<const int64_t> dims,
|
||||
xla::ifrt::Device* device, xla::ifrt::MemoryKind memory_kind),
|
||||
(const, final));
|
||||
// Mock of the `(dtype, shape, sharding)` overload of
// `Client::GetDefaultLayout()`. The `MockClient(delegated)` constructor in
// mock.cc installs a default action that forwards the three arguments to
// `delegated_->GetDefaultLayout()`.
MOCK_METHOD(absl::StatusOr<CustomLayoutRef>, GetDefaultLayout,
|
||||
(DType dtype, const Shape& shape, const ShardingRef& sharding),
|
||||
(const, final));
|
||||
MOCK_METHOD(tsl::RCReference<xla::ifrt::UserContext>, CreateUserContext, (),
|
||||
(final));
|
||||
// LINT.ThenChange(mock.cc:MockClientDelegation)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user