Use Deserializer lambda for embedded thunks in DynamicSliceThunk

PiperOrigin-RevId: 826474606
This commit is contained in:
Eusebio Durán Montaña 2025-10-31 07:04:55 -07:00 committed by TensorFlower Gardener
parent 26d0882419
commit ecc2510eb0
4 changed files with 18 additions and 8 deletions

View File

@ -226,7 +226,6 @@ cc_library(
":sequential_thunk", ":sequential_thunk",
":thunk", ":thunk",
":thunk_proto_cc", ":thunk_proto_cc",
":thunk_proto_deserialization",
"//xla:literal", "//xla:literal",
"//xla:literal_util", "//xla:literal_util",
"//xla:shape_util", "//xla:shape_util",
@ -270,6 +269,7 @@ xla_test(
":gemm_thunk", ":gemm_thunk",
":sequential_thunk", ":sequential_thunk",
":thunk", ":thunk",
":thunk_proto_deserialization",
"//xla:shape_util", "//xla:shape_util",
"//xla:xla_data_proto_cc", "//xla:xla_data_proto_cc",
"//xla/ffi", "//xla/ffi",

View File

@ -41,7 +41,6 @@ limitations under the License.
#include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/backends/gpu/runtime/thunk.pb.h"
#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h"
#include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h" #include "xla/literal.h"
@ -617,7 +616,8 @@ absl::StatusOr<ThunkProto> DynamicSliceThunk::ToProto() const {
absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto( absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto, ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations, absl::Span<const BufferAllocation> buffer_allocations,
absl::Span<const BufferAllocation> fake_allocations) { absl::Span<const BufferAllocation> fake_allocations,
const Deserializer& deserializer) {
// offset_as_function_of_indvar_metadata // offset_as_function_of_indvar_metadata
std::optional<OffsetAsFunctionOfIndvarModulesMetadata> std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
offset_as_function_of_indvar_metadata; offset_as_function_of_indvar_metadata;
@ -677,9 +677,9 @@ absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
// embedded_thunk // embedded_thunk
std::vector<std::unique_ptr<Thunk>> embedded_thunks; std::vector<std::unique_ptr<Thunk>> embedded_thunks;
for (const auto& thunk_proto : proto.embedded_thunk().thunks()) { for (const auto& thunk_proto : proto.embedded_thunk().thunks()) {
TF_ASSIGN_OR_RETURN(auto thunk, TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> embedded_thunk,
DeserializeThunkProto(thunk_proto, fake_allocations)); deserializer(thunk_proto));
embedded_thunks.push_back(std::move(thunk)); embedded_thunks.push_back(std::move(embedded_thunk));
} }
// leave fake_allocations empty, because we manage their lifetime outside // leave fake_allocations empty, because we manage their lifetime outside

View File

@ -186,10 +186,12 @@ class DynamicSliceThunk : public Thunk {
// replaced during execution in `ExecuteOnStream` with the actual (dynamic) // replaced during execution in `ExecuteOnStream` with the actual (dynamic)
// slices. We have to create these outside of this method to manage their // slices. We have to create these outside of this method to manage their
// lifetime correctly. // lifetime correctly.
// `deserializer`: The deserializer is used to deserialize the embedded thunk.
static absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> FromProto( static absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> FromProto(
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto, ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations, absl::Span<const BufferAllocation> buffer_allocations,
absl::Span<const BufferAllocation> fake_allocations); absl::Span<const BufferAllocation> fake_allocations,
const Deserializer& deserializer);
std::optional<const OffsetAsFunctionOfIndvarModulesMetadata*> std::optional<const OffsetAsFunctionOfIndvarModulesMetadata*>
get_offset_function() const { get_offset_function() const {

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/gemm_thunk.h" #include "xla/backends/gpu/runtime/gemm_thunk.h"
#include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h"
#include "xla/ffi/attribute_map.h" #include "xla/ffi/attribute_map.h"
#include "xla/ffi/ffi.h" #include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h" #include "xla/ffi/ffi_api.h"
@ -108,11 +109,18 @@ void CheckProtoRoundTrip(const DynamicSliceThunk& thunk,
BufferAllocation(i, arguments[i].value().allocation()->size(), 0)); BufferAllocation(i, arguments[i].value().allocation()->size(), 0));
} }
} }
Thunk::Deserializer deserializer =
[&buffer_allocations](const ThunkProto& thunk_proto)
-> absl::StatusOr<std::unique_ptr<Thunk>> {
return DeserializeThunkProto(thunk_proto, buffer_allocations);
};
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
auto thunk_from_proto, auto thunk_from_proto,
DynamicSliceThunk::FromProto(Thunk::ThunkInfo(), proto, DynamicSliceThunk::FromProto(Thunk::ThunkInfo(), proto,
/*buffer_allocations=*/buffer_allocations, /*buffer_allocations=*/buffer_allocations,
/*fake_allocations=*/fake_allocations_span)); /*fake_allocations=*/fake_allocations_span,
deserializer));
TF_ASSERT_OK_AND_ASSIGN(auto proto_roundtrip, thunk_from_proto->ToProto()); TF_ASSERT_OK_AND_ASSIGN(auto proto_roundtrip, thunk_from_proto->ToProto());
auto dynamic_slice_thunk_proto_roundtrip = auto dynamic_slice_thunk_proto_roundtrip =
proto_roundtrip.dynamic_slice_thunk(); proto_roundtrip.dynamic_slice_thunk();