Use Deserializer lambda for embedded thunks in DynamicSliceThunk

PiperOrigin-RevId: 826474606
This commit is contained in:
Eusebio Durán Montaña 2025-10-31 07:04:55 -07:00 committed by TensorFlower Gardener
parent 26d0882419
commit ecc2510eb0
4 changed files with 18 additions and 8 deletions

View File

@ -226,7 +226,6 @@ cc_library(
":sequential_thunk",
":thunk",
":thunk_proto_cc",
":thunk_proto_deserialization",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
@ -270,6 +269,7 @@ xla_test(
":gemm_thunk",
":sequential_thunk",
":thunk",
":thunk_proto_deserialization",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/ffi",

View File

@ -41,7 +41,6 @@ limitations under the License.
#include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/thunk.pb.h"
#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h"
#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
@ -617,7 +616,8 @@ absl::StatusOr<ThunkProto> DynamicSliceThunk::ToProto() const {
absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations,
absl::Span<const BufferAllocation> fake_allocations) {
absl::Span<const BufferAllocation> fake_allocations,
const Deserializer& deserializer) {
// offset_as_function_of_indvar_metadata
std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
offset_as_function_of_indvar_metadata;
@ -677,9 +677,9 @@ absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
// embedded_thunk
std::vector<std::unique_ptr<Thunk>> embedded_thunks;
for (const auto& thunk_proto : proto.embedded_thunk().thunks()) {
TF_ASSIGN_OR_RETURN(auto thunk,
DeserializeThunkProto(thunk_proto, fake_allocations));
embedded_thunks.push_back(std::move(thunk));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> embedded_thunk,
deserializer(thunk_proto));
embedded_thunks.push_back(std::move(embedded_thunk));
}
// leave fake_allocations empty, because we manage their lifetime outside

View File

@ -186,10 +186,12 @@ class DynamicSliceThunk : public Thunk {
// replaced during execution in `ExecuteOnStream` with the actual (dynamic)
// slices. We have to create these outside of this method to manage their
// lifetime correctly.
// `deserializer`: The deserializer is used to deserialize the embedded thunk.
static absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> FromProto(
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations,
absl::Span<const BufferAllocation> fake_allocations);
absl::Span<const BufferAllocation> fake_allocations,
const Deserializer& deserializer);
std::optional<const OffsetAsFunctionOfIndvarModulesMetadata*>
get_offset_function() const {

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/gemm_thunk.h"
#include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h"
#include "xla/ffi/attribute_map.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
@ -108,11 +109,18 @@ void CheckProtoRoundTrip(const DynamicSliceThunk& thunk,
BufferAllocation(i, arguments[i].value().allocation()->size(), 0));
}
}
Thunk::Deserializer deserializer =
[&buffer_allocations](const ThunkProto& thunk_proto)
-> absl::StatusOr<std::unique_ptr<Thunk>> {
return DeserializeThunkProto(thunk_proto, buffer_allocations);
};
TF_ASSERT_OK_AND_ASSIGN(
auto thunk_from_proto,
DynamicSliceThunk::FromProto(Thunk::ThunkInfo(), proto,
/*buffer_allocations=*/buffer_allocations,
/*fake_allocations=*/fake_allocations_span));
/*fake_allocations=*/fake_allocations_span,
deserializer));
TF_ASSERT_OK_AND_ASSIGN(auto proto_roundtrip, thunk_from_proto->ToProto());
auto dynamic_slice_thunk_proto_roundtrip =
proto_roundtrip.dynamic_slice_thunk();