From 6ff7f9c87f033e8011d8214767ab5c4de04ed2c7 Mon Sep 17 00:00:00 2001 From: Aliia Khasanova Date: Fri, 31 Oct 2025 10:30:18 -0700 Subject: [PATCH] Add de/serializaton of `fake_allocations` in `DynamicSliceThunk`. PiperOrigin-RevId: 826541399 --- .../xla/xla/backends/gpu/runtime/BUILD | 2 ++ .../gpu/runtime/dynamic_slice_thunk.cc | 23 +++++++++++++------ .../gpu/runtime/dynamic_slice_thunk.h | 8 +------ .../gpu/runtime/dynamic_slice_thunk_test.cc | 9 ++++---- .../xla/xla/backends/gpu/runtime/thunk.h | 4 ++++ .../xla/xla/backends/gpu/runtime/thunk.proto | 2 ++ .../runtime/thunk_proto_deserialization.cc | 11 +++++++++ 7 files changed, 41 insertions(+), 18 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index ca3c04404c9..5d08e43aa7c 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -2452,6 +2452,7 @@ tf_proto_library( "//xla:xla_data_proto", "//xla/core/host_offloading:host_offloading_executable_proto", "//xla/service:buffer_assignment_proto", + "//xla/service:hlo_proto", "//xla/service/gpu:backend_configs", "//xla/service/gpu:gpu_conv_runner_proto", "//xla/service/gpu:gpu_norm_runner_proto", @@ -2490,6 +2491,7 @@ cc_library( ":convolution_thunk", ":copy_thunk", ":cudnn_thunk", + ":dynamic_slice_thunk", ":fft_thunk", ":gemm_thunk", ":gpublas_lt_matmul_thunk", diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc index b98d2d7d3ad..6f37c3cd13d 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc @@ -610,14 +610,19 @@ absl::StatusOr DynamicSliceThunk::ToProto() const { ->mutable_offset_as_function_of_indvar_modules_metadata(), offset_as_function_of_indvar_metadata_->ToProto()); } + + // fake_allocations + for (const auto& fake_allocation : fake_allocations_) { + *dynamic_slice_proto->add_fake_allocations() = fake_allocation.ToProto(); + } + return proto; } absl::StatusOr> DynamicSliceThunk::FromProto( ThunkInfo thunk_info, const DynamicSliceThunkProto& proto, absl::Span buffer_allocations, - absl::Span fake_allocations, - const Deserializer& deserializer) { + const DeserializerWithCustomAllocations& deserializer) { // offset_as_function_of_indvar_metadata std::optional offset_as_function_of_indvar_metadata; @@ -674,20 +679,24 @@ absl::StatusOr> DynamicSliceThunk::FromProto( } } + // fake_allocations + std::vector fake_allocations; + for (const auto& fake_allocation_proto : proto.fake_allocations()) { + fake_allocations.push_back( + BufferAllocation::FromProto(fake_allocation_proto)); + } + // embedded_thunk std::vector> embedded_thunks; for (const auto& thunk_proto : proto.embedded_thunk().thunks()) { TF_ASSIGN_OR_RETURN(std::unique_ptr embedded_thunk, - deserializer(thunk_proto)); + deserializer(thunk_proto, fake_allocations)); embedded_thunks.push_back(std::move(embedded_thunk)); } - // leave fake_allocations empty, because we manage their lifetime outside - // of this function. return std::make_unique( thunk_info, std::make_unique(std::move(embedded_thunks)), - std::move(arguments), - /*fake_allocations=*/std::vector(), std::move(offsets), + std::move(arguments), std::move(fake_allocations), std::move(offsets), std::move(orig_shapes), std::move(sliced_shapes), std::move(offset_byte_sizes), std::move(offset_as_function_of_indvar_metadata)); diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h index efb470aafc0..155dd32a8aa 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.h @@ -181,17 +181,11 @@ class DynamicSliceThunk : public Thunk { // `buffer_allocations`: the actual buffer allocations; required to parse the // `arguments` (BufferAllocation::Slice) -- the tensors that we are later // slicing from. - // `fake_allocations`: The fake allocations that are used as - // placeholders during creation of the embedded thunk. These are being - // replaced during execution in `ExecuteOnStream` with the actual (dynamic) - // slices. We have to create these outside of this method to manage their - // lifetime correctly. // `deserializer`: The deserializer is used to deserialize the embedded thunk. static absl::StatusOr> FromProto( ThunkInfo thunk_info, const DynamicSliceThunkProto& proto, absl::Span buffer_allocations, - absl::Span fake_allocations, - const Deserializer& deserializer); + const DeserializerWithCustomAllocations& deserializer); std::optional get_offset_function() const { diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc index d3b71196e3b..f83e6666eb5 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc @@ -110,16 +110,17 @@ void CheckProtoRoundTrip(const DynamicSliceThunk& thunk, } } - Thunk::Deserializer deserializer = - [&buffer_allocations](const ThunkProto& thunk_proto) + Thunk::DeserializerWithCustomAllocations deserializer = + [](const ThunkProto& thunk_proto, + absl::Span fake_allocations_span) -> absl::StatusOr> { - return DeserializeThunkProto(thunk_proto, buffer_allocations); + return DeserializeThunkProto(thunk_proto, fake_allocations_span); }; + TF_ASSERT_OK_AND_ASSIGN( auto thunk_from_proto, DynamicSliceThunk::FromProto(Thunk::ThunkInfo(), proto, /*buffer_allocations=*/buffer_allocations, - /*fake_allocations=*/fake_allocations_span, deserializer)); TF_ASSERT_OK_AND_ASSIGN(auto proto_roundtrip, thunk_from_proto->ToProto()); auto dynamic_slice_thunk_proto_roundtrip = diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk.h b/third_party/xla/xla/backends/gpu/runtime/thunk.h index 70e8c0947f6..7f2287d4cd1 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/thunk.h @@ -571,6 +571,10 @@ class Thunk { absl::AnyInvocable>( const ThunkProto&) const>; + using DeserializerWithCustomAllocations = + absl::AnyInvocable>( + const ThunkProto&, absl::Span) const>; + void add_control_predecessor(const Thunk* control_predecessor) { control_predecessors_.push_back(control_predecessor); } diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk.proto b/third_party/xla/xla/backends/gpu/runtime/thunk.proto index 7543054d6c0..156ae1120e4 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk.proto +++ b/third_party/xla/xla/backends/gpu/runtime/thunk.proto @@ -24,6 +24,7 @@ import "xla/service/buffer_assignment.proto"; import "xla/service/gpu/gpu_conv_runner.proto"; import "xla/service/gpu/gpu_norm_runner.proto"; import "xla/service/gpu/launch_dimensions.proto"; +import "xla/service/hlo.proto"; import "xla/stream_executor/gpu/gpu_blas_lt.proto"; import "xla/stream_executor/gpu/tma_metadata.proto"; import "xla/stream_executor/launch_dim.proto"; @@ -167,6 +168,7 @@ message DynamicSliceThunkProto { repeated OptionalInt64Proto offset_byte_sizes = 6; optional OffsetAsFunctionOfIndvarModulesMetadataProto offset_as_function_of_indvar_modules_metadata = 7; + repeated BufferAllocationProto fake_allocations = 8; } message MemzeroThunkProto { diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization.cc b/third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization.cc index c2ec6a97527..b47370eb955 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization.cc +++ b/third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/convolution_thunk.h" #include "xla/backends/gpu/runtime/copy_thunk.h" #include "xla/backends/gpu/runtime/cudnn_thunk.h" +#include "xla/backends/gpu/runtime/dynamic_slice_thunk.h" #include "xla/backends/gpu/runtime/fft_thunk.h" #include "xla/backends/gpu/runtime/gemm_thunk.h" #include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" @@ -167,6 +168,16 @@ absl::StatusOr> DeserializeThunkProto( return Memset32BitValueThunk::FromProto( std::move(thunk_info), thunk_proto.memset32bit_value_thunk(), buffer_allocations); + case ThunkProto::kDynamicSliceThunk: { + auto deserializer = + [](const ThunkProto& thunk_proto, + absl::Span custom_allocations) { + return DeserializeThunkProto(thunk_proto, custom_allocations); + }; + return DynamicSliceThunk::FromProto(std::move(thunk_info), + thunk_proto.dynamic_slice_thunk(), + buffer_allocations, deserializer); + } default: std::optional unsupported_thunk_type = GetStoredThunkTypeName(thunk_proto);