[IFRT Proxy] Array::pjrt_layout() uses nullptr to indicate a default layout
IFRT Proxy now returns `nullptr` when it knows that an Array's layout is a default layout. User code has already been migrated to handle this behavior gracefully, obtaining a concrete default layout as before.

Caveat: the IFRT Proxy client infers the layouts of output arrays from `LoadedExecutable::GetOutputLayouts()`, which today always returns concrete layouts. As a result, these output arrays use concrete layouts even for default layouts, whereas the corresponding arrays on the server side use `nullptr` for default layouts. This is currently acceptable because all users convert the layout into a concrete one before using it; the behavior will eventually change so that the IFRT Proxy client reflects the server-side array layouts more accurately.

PiperOrigin-RevId: 821741105
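For illustration, below is a minimal sketch (not part of this change) of how caller code might resolve the new `nullptr`-means-default convention into a concrete layout. The helper name `ResolveConcreteLayout` is hypothetical; the `Client::GetDefaultPjRtLayout()` and `Sharding::GetShardShape()` calls mirror the fallback that the removed client-side `Array::pjrt_layout()` implementation in this diff used to perform.

// Hypothetical helper, not part of this commit: resolve an Array layout that
// may be nullptr (meaning "default layout") into a concrete layout.
absl::StatusOr<std::shared_ptr<const xla::PjRtLayout>> ResolveConcreteLayout(
    xla::ifrt::Client* client, const xla::ifrt::Array& array) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<const xla::PjRtLayout> layout,
                      array.pjrt_layout());
  if (layout != nullptr) {
    return layout;  // A concrete (custom) layout was supplied at creation.
  }
  // nullptr now means "default layout": ask the client for the concrete
  // default layout, as the removed client-side fallback used to do.
  TF_ASSIGN_OR_RETURN(xla::ifrt::Shape shard_shape,
                      array.sharding().GetShardShape(array.shape()));
  return client->GetDefaultPjRtLayout(
      array.dtype(), shard_shape.dims(),
      array.sharding().devices()->devices().front(),
      array.sharding().memory_kind());
}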
This commit is contained in:
parent
0e09f486e7
commit
cc9fd2b254
@@ -589,10 +589,11 @@ absl::StatusOr<xla::ifrt::ArrayRef> Array::AssembleArrayFromSingleDeviceArrays(
   // We assume that all shards have the same layout.
   const xla::ifrt::ArrayRef& rcref = arrays[0];
   Array* array = llvm::cast<Array>(rcref.get());
+  TF_ASSIGN_OR_RETURN(std::shared_ptr<const xla::PjRtLayout> layout,
+                      array->pjrt_layout());
   return xla::ifrt::ArrayRef(tsl::MakeRef<Array>(
       client, std::move(rpc_helper), dtype, std::move(shape),
-      std::move(sharding), result_handle, array->custom_layout()));
+      std::move(sharding), result_handle, std::move(layout)));
 }

 absl::StatusOr<std::vector<xla::ifrt::ArrayRef>> Array::RemapArrays(
@@ -664,7 +665,9 @@ absl::StatusOr<std::vector<xla::ifrt::ArrayRef>> Array::RemapArrays(
     if (output_layouts[mapping.out_array] == nullptr) {
       const xla::ifrt::ArrayRef& rcref = arrays[mapping.in_array];
       Array* array = llvm::cast<Array>(rcref.get());
-      output_layouts[mapping.out_array] = array->custom_layout();
+      TF_ASSIGN_OR_RETURN(std::shared_ptr<const xla::PjRtLayout> layout,
+                          array->pjrt_layout());
+      output_layouts[mapping.out_array] = std::move(layout);
     }
   }
@@ -751,8 +754,7 @@ Array::DisassembleIntoSingleDeviceArrays(
   for (int i = 0; i < result_handles.size(); ++i) {
     result.push_back(xla::ifrt::ArrayRef(tsl::MakeRef<Array>(
         client_, rpc_helper_, dtype_, std::move(shape_and_shardings[i].first),
-        std::move(shape_and_shardings[i].second), result_handles[i],
-        this->custom_layout())));
+        std::move(shape_and_shardings[i].second), result_handles[i], layout_)));
   }

   return result;
@@ -792,7 +794,7 @@ absl::StatusOr<xla::ifrt::ArrayRef> Array::FullyReplicatedShard(

   return xla::ifrt::ArrayRef(tsl::MakeRef<Array>(
       client_, rpc_helper_, dtype_, shape_, std::move(single_device_sharding),
-      result_handle, this->custom_layout()));
+      result_handle, layout_));
 }

 tsl::Future<> Array::CopyToStringHostBuffer(
@@ -940,15 +942,7 @@ tsl::Future<> Array::CopyToHostBuffer(
 }

 absl::StatusOr<std::shared_ptr<const PjRtLayout>> Array::pjrt_layout() const {
-  absl::MutexLock l(mu_);
-  if (custom_layout_ != nullptr) {
-    return custom_layout_;
-  }
-
-  TF_ASSIGN_OR_RETURN(auto shard_shape, sharding_->GetShardShape(shape_));
-  return client_->GetDefaultPjRtLayout(dtype_, shard_shape.dims(),
-                                       sharding_->devices()->devices().front(),
-                                       sharding_->memory_kind());
+  return layout_;
 }

 xla::ifrt::Client* Array::client() const { return client_; }
@@ -108,7 +108,7 @@ class Array final : public llvm::RTTIExtends<Array, xla::ifrt::Array> {
         dtype_(dtype),
         shape_(std::move(shape)),
         sharding_(std::move(sharding)),
-        custom_layout_(std::move(layout)),
+        layout_(std::move(layout)),
         user_context_(UserContextScope::current()),
         handle_(arr_handle) {}
@@ -140,10 +140,6 @@ class Array final : public llvm::RTTIExtends<Array, xla::ifrt::Array> {
     return handle_;
   }

-  std::shared_ptr<const xla::PjRtLayout> custom_layout() const {
-    return custom_layout_;
-  }
-
   xla::ifrt::Client* client() const override;
   tsl::Future<> GetReadyFuture() const override;
   tsl::Future<> Delete() override;
@@ -191,11 +187,7 @@ class Array final : public llvm::RTTIExtends<Array, xla::ifrt::Array> {
   const DType dtype_;
   const Shape shape_;
   const ShardingRef sharding_;

-  // This is layout explicitly supplied at creation time. we explicitly
-  // distinguish it from default layouts since some functions
-  // behaves differently depending on where the layout came from.
-  const std::shared_ptr<const xla::PjRtLayout> custom_layout_;
+  const std::shared_ptr<const xla::PjRtLayout> layout_;

   const UserContextRef user_context_;
@@ -161,14 +161,11 @@ TEST_F(ArrayTest, FullyReplicatedShard) {
 }

 TEST_F(ArrayTest, GetDefaultPjRtLayoutSuccess) {
-  ON_CALL(*mock_client_, GetDefaultPjRtLayout).WillByDefault(Return(kLayout1));
-
   auto array = tsl::MakeRef<Array>(
       mock_client_.get(), rpc_helper_, DType(DType::Kind::kBF16), Shape({}),
       sharding_, ArrayHandle{1234}, /*layout=*/nullptr);
   TF_ASSERT_OK_AND_ASSIGN(auto layout_1, array->pjrt_layout());
-  ASSERT_NE(layout_1, nullptr);
-  EXPECT_EQ(*layout_1, *kLayout1);
+  EXPECT_EQ(layout_1, nullptr);
 }

 TEST_F(ArrayTest, GetCustomLayoutSuccess) {
@@ -306,8 +303,7 @@ TEST_F(ArrayTest, AssembleArrayFromSingleDeviceArraysDefaultPjRtLayoutSuccess) {
       SingleDeviceShardSemantics::kAllShards);
   TF_ASSERT_OK(result.status());
   TF_ASSERT_OK_AND_ASSIGN(auto layout, result.value()->pjrt_layout());
-  ASSERT_NE(layout, nullptr);
-  EXPECT_EQ(*layout, *kLayout1);
+  EXPECT_EQ(layout, nullptr);
 }

 TEST_F(ArrayTest, RemapArraysSuccess) {
@@ -361,10 +361,12 @@ absl::StatusOr<std::vector<xla::ifrt::ArrayRef>> Client::CopyArrays(
         arrays[i]->sharding().WithDeviceAssignment(devices, memory_kind));
     auto* proxy_array = llvm::cast<xla::ifrt::proxy::Array>(arrays[i].get());
     CHECK(proxy_array != nullptr);
-    new_arrays.push_back(tsl::MakeRef<Array>(
-        this, rpc_helper_, arrays[i]->dtype(), arrays[i]->shape(),
-        std::move(new_sharding), ArrayHandle{result_handles[i]},
-        /*layout=*/proxy_array->custom_layout()));
+    TF_ASSIGN_OR_RETURN(std::shared_ptr<const xla::PjRtLayout> layout,
+                        proxy_array->pjrt_layout());
+    new_arrays.push_back(
+        tsl::MakeRef<Array>(this, rpc_helper_, arrays[i]->dtype(),
+                            arrays[i]->shape(), std::move(new_sharding),
+                            ArrayHandle{result_handles[i]}, std::move(layout)));
   }
   return new_arrays;
 }
@@ -382,10 +382,12 @@ TEST_P(ClientTest, CopyArraysDefaultLayoutSuccess) {
       client_->CopyArrays(absl::MakeSpan(arrays), std::move(device_list),
                           MemoryKind("mock"), ArrayCopySemantics::kAlwaysCopy));
   ASSERT_THAT(copied_arrays, SizeIs(2));
-  EXPECT_EQ(llvm::cast<Array>(copied_arrays[0].get())->custom_layout(),
-            nullptr);
-  EXPECT_EQ(llvm::cast<Array>(copied_arrays[1].get())->custom_layout(),
-            nullptr);
+  TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<const xla::PjRtLayout> layout_1,
+                          copied_arrays[0].get()->pjrt_layout());
+  EXPECT_EQ(layout_1, nullptr);
+  TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<const xla::PjRtLayout> layout_2,
+                          copied_arrays[1].get()->pjrt_layout());
+  EXPECT_EQ(layout_2, nullptr);
 }

 TEST_P(ClientTest, CopyArraysCustomLayoutSuccess) {
@@ -418,12 +420,12 @@ TEST_P(ClientTest, CopyArraysCustomLayoutSuccess) {
      client_->CopyArrays(absl::MakeSpan(arrays), std::move(device_list),
                          MemoryKind("mock"), ArrayCopySemantics::kAlwaysCopy));
   ASSERT_THAT(copied_arrays, SizeIs(2));
-  EXPECT_EQ(
-      llvm::cast<Array>(copied_arrays[0].get())->custom_layout()->ToString(),
-      layout_1_->ToString());
-  EXPECT_EQ(
-      llvm::cast<Array>(copied_arrays[1].get())->custom_layout()->ToString(),
-      layout_2_->ToString());
+  TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<const xla::PjRtLayout> layout_1,
+                          copied_arrays[0].get()->pjrt_layout());
+  EXPECT_EQ(layout_1->ToString(), layout_1_->ToString());
+  TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<const xla::PjRtLayout> layout_2,
+                          copied_arrays[1].get()->pjrt_layout());
+  EXPECT_EQ(layout_2->ToString(), layout_2_->ToString());
 }

 TEST_P(ClientTest, GetDefaultDeviceAssignmentSuccess) {
@@ -744,6 +744,15 @@ LoadedExecutable::Execute(absl::Span<xla::ifrt::ArrayRef> args,
       output_spec_cache_->Retrieve().has_value();

   xla::ifrt::LoadedExecutable::ExecuteResult result;
+  // TODO(hyeontaek): `GetOutputLayouts()` uses a concrete layout for a
+  // default layout. This will change as proper IFRT layout support is fleshed
+  // out. While the code here using `layouts` will automatically benefit from
+  // the semantics change for `GetOutputLayouts()`, we would have a slightly
+  // inconsistent state here until the change happens where output arrays use a
+  // concrete layout for a default layout. This will not cause an issue for the
+  // time being when the user always uses concrete layouts, but we would need to
+  // resolve this issue before the user begins to use `nullptr` default layouts
+  // without resolving it to a concrete layout.
   absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>> layouts =
       GetOutputLayouts();