From ecc2510eb0b75a0e7509cde891550f78f687485a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eusebio=20Dur=C3=A1n=20Monta=C3=B1a?= Date: Fri, 31 Oct 2025 07:04:55 -0700 Subject: [PATCH] Use Deserializer lambda for embedded thunks in `DynamicSliceThunk` PiperOrigin-RevId: 826474606 --- third_party/xla/xla/backends/gpu/runtime/BUILD | 2 +- .../xla/backends/gpu/runtime/dynamic_slice_thunk.cc | 10 +++++----- .../xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h | 4 +++- .../backends/gpu/runtime/dynamic_slice_thunk_test.cc | 10 +++++++++- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 023bf01cad4..ca3c04404c9 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -226,7 +226,6 @@ cc_library( ":sequential_thunk", ":thunk", ":thunk_proto_cc", - ":thunk_proto_deserialization", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -270,6 +269,7 @@ xla_test( ":gemm_thunk", ":sequential_thunk", ":thunk", + ":thunk_proto_deserialization", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/ffi", diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc index 98763d288d4..b98d2d7d3ad 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc @@ -41,7 +41,6 @@ limitations under the License. #include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.pb.h" -#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" @@ -617,7 +616,8 @@ absl::StatusOr DynamicSliceThunk::ToProto() const { absl::StatusOr> DynamicSliceThunk::FromProto( ThunkInfo thunk_info, const DynamicSliceThunkProto& proto, absl::Span buffer_allocations, - absl::Span fake_allocations) { + absl::Span fake_allocations, + const Deserializer& deserializer) { // offset_as_function_of_indvar_metadata std::optional offset_as_function_of_indvar_metadata; @@ -677,9 +677,9 @@ absl::StatusOr> DynamicSliceThunk::FromProto( // embedded_thunk std::vector> embedded_thunks; for (const auto& thunk_proto : proto.embedded_thunk().thunks()) { - TF_ASSIGN_OR_RETURN(auto thunk, - DeserializeThunkProto(thunk_proto, fake_allocations)); - embedded_thunks.push_back(std::move(thunk)); + TF_ASSIGN_OR_RETURN(std::unique_ptr embedded_thunk, + deserializer(thunk_proto)); + embedded_thunks.push_back(std::move(embedded_thunk)); } // leave fake_allocations empty, because we manage their lifetime outside diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h index c0a08b33539..efb470aafc0 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h @@ -186,10 +186,12 @@ class DynamicSliceThunk : public Thunk { // replaced during execution in `ExecuteOnStream` with the actual (dynamic) // slices. We have to create these outside of this method to manage their // lifetime correctly. + // `deserializer`: The deserializer is used to deserialize the embedded thunk. static absl::StatusOr> FromProto( ThunkInfo thunk_info, const DynamicSliceThunkProto& proto, absl::Span buffer_allocations, - absl::Span fake_allocations); + absl::Span fake_allocations, + const Deserializer& deserializer); std::optional get_offset_function() const { diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc index 6c09bb8a228..d3b71196e3b 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/gemm_thunk.h" #include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" +#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h" #include "xla/ffi/attribute_map.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" @@ -108,11 +109,18 @@ void CheckProtoRoundTrip(const DynamicSliceThunk& thunk, BufferAllocation(i, arguments[i].value().allocation()->size(), 0)); } } + + Thunk::Deserializer deserializer = + [&buffer_allocations](const ThunkProto& thunk_proto) + -> absl::StatusOr> { + return DeserializeThunkProto(thunk_proto, buffer_allocations); + }; TF_ASSERT_OK_AND_ASSIGN( auto thunk_from_proto, DynamicSliceThunk::FromProto(Thunk::ThunkInfo(), proto, /*buffer_allocations=*/buffer_allocations, - /*fake_allocations=*/fake_allocations_span)); + /*fake_allocations=*/fake_allocations_span, + deserializer)); TF_ASSERT_OK_AND_ASSIGN(auto proto_roundtrip, thunk_from_proto->ToProto()); auto dynamic_slice_thunk_proto_roundtrip = proto_roundtrip.dynamic_slice_thunk();