Add de/serializaton of fake_allocations in DynamicSliceThunk.

PiperOrigin-RevId: 826541399
This commit is contained in:
Aliia Khasanova 2025-10-31 10:30:18 -07:00 committed by TensorFlower Gardener
parent ecc2510eb0
commit 6ff7f9c87f
7 changed files with 41 additions and 18 deletions

View File

@ -2452,6 +2452,7 @@ tf_proto_library(
"//xla:xla_data_proto",
"//xla/core/host_offloading:host_offloading_executable_proto",
"//xla/service:buffer_assignment_proto",
"//xla/service:hlo_proto",
"//xla/service/gpu:backend_configs",
"//xla/service/gpu:gpu_conv_runner_proto",
"//xla/service/gpu:gpu_norm_runner_proto",
@ -2490,6 +2491,7 @@ cc_library(
":convolution_thunk",
":copy_thunk",
":cudnn_thunk",
":dynamic_slice_thunk",
":fft_thunk",
":gemm_thunk",
":gpublas_lt_matmul_thunk",

View File

@ -610,14 +610,19 @@ absl::StatusOr<ThunkProto> DynamicSliceThunk::ToProto() const {
->mutable_offset_as_function_of_indvar_modules_metadata(),
offset_as_function_of_indvar_metadata_->ToProto());
}
// fake_allocations
for (const auto& fake_allocation : fake_allocations_) {
*dynamic_slice_proto->add_fake_allocations() = fake_allocation.ToProto();
}
return proto;
}
absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations,
absl::Span<const BufferAllocation> fake_allocations,
const Deserializer& deserializer) {
const DeserializerWithCustomAllocations& deserializer) {
// offset_as_function_of_indvar_metadata
std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
offset_as_function_of_indvar_metadata;
@ -674,20 +679,24 @@ absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
}
}
// fake_allocations
std::vector<BufferAllocation> fake_allocations;
for (const auto& fake_allocation_proto : proto.fake_allocations()) {
fake_allocations.push_back(
BufferAllocation::FromProto(fake_allocation_proto));
}
// embedded_thunk
std::vector<std::unique_ptr<Thunk>> embedded_thunks;
for (const auto& thunk_proto : proto.embedded_thunk().thunks()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> embedded_thunk,
deserializer(thunk_proto));
deserializer(thunk_proto, fake_allocations));
embedded_thunks.push_back(std::move(embedded_thunk));
}
// leave fake_allocations empty, because we manage their lifetime outside
// of this function.
return std::make_unique<DynamicSliceThunk>(
thunk_info, std::make_unique<ThunkSequence>(std::move(embedded_thunks)),
std::move(arguments),
/*fake_allocations=*/std::vector<BufferAllocation>(), std::move(offsets),
std::move(arguments), std::move(fake_allocations), std::move(offsets),
std::move(orig_shapes), std::move(sliced_shapes),
std::move(offset_byte_sizes),
std::move(offset_as_function_of_indvar_metadata));

View File

@ -181,17 +181,11 @@ class DynamicSliceThunk : public Thunk {
// `buffer_allocations`: the actual buffer allocations; required to parse the
// `arguments` (BufferAllocation::Slice) -- the tensors that we are later
// slicing from.
// `fake_allocations`: The fake allocations that are used as
// placeholders during creation of the embedded thunk. These are being
// replaced during execution in `ExecuteOnStream` with the actual (dynamic)
// slices. We have to create these outside of this method to manage their
// lifetime correctly.
// `deserializer`: The deserializer is used to deserialize the embedded thunk.
static absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> FromProto(
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations,
absl::Span<const BufferAllocation> fake_allocations,
const Deserializer& deserializer);
const DeserializerWithCustomAllocations& deserializer);
std::optional<const OffsetAsFunctionOfIndvarModulesMetadata*>
get_offset_function() const {

View File

@ -110,16 +110,17 @@ void CheckProtoRoundTrip(const DynamicSliceThunk& thunk,
}
}
Thunk::Deserializer deserializer =
[&buffer_allocations](const ThunkProto& thunk_proto)
Thunk::DeserializerWithCustomAllocations deserializer =
[](const ThunkProto& thunk_proto,
absl::Span<const BufferAllocation> fake_allocations_span)
-> absl::StatusOr<std::unique_ptr<Thunk>> {
return DeserializeThunkProto(thunk_proto, buffer_allocations);
return DeserializeThunkProto(thunk_proto, fake_allocations_span);
};
TF_ASSERT_OK_AND_ASSIGN(
auto thunk_from_proto,
DynamicSliceThunk::FromProto(Thunk::ThunkInfo(), proto,
/*buffer_allocations=*/buffer_allocations,
/*fake_allocations=*/fake_allocations_span,
deserializer));
TF_ASSERT_OK_AND_ASSIGN(auto proto_roundtrip, thunk_from_proto->ToProto());
auto dynamic_slice_thunk_proto_roundtrip =

View File

@ -571,6 +571,10 @@ class Thunk {
absl::AnyInvocable<absl::StatusOr<std::unique_ptr<Thunk>>(
const ThunkProto&) const>;
using DeserializerWithCustomAllocations =
absl::AnyInvocable<absl::StatusOr<std::unique_ptr<Thunk>>(
const ThunkProto&, absl::Span<const BufferAllocation>) const>;
void add_control_predecessor(const Thunk* control_predecessor) {
control_predecessors_.push_back(control_predecessor);
}

View File

@ -24,6 +24,7 @@ import "xla/service/buffer_assignment.proto";
import "xla/service/gpu/gpu_conv_runner.proto";
import "xla/service/gpu/gpu_norm_runner.proto";
import "xla/service/gpu/launch_dimensions.proto";
import "xla/service/hlo.proto";
import "xla/stream_executor/gpu/gpu_blas_lt.proto";
import "xla/stream_executor/gpu/tma_metadata.proto";
import "xla/stream_executor/launch_dim.proto";
@ -167,6 +168,7 @@ message DynamicSliceThunkProto {
repeated OptionalInt64Proto offset_byte_sizes = 6;
optional OffsetAsFunctionOfIndvarModulesMetadataProto
offset_as_function_of_indvar_modules_metadata = 7;
repeated BufferAllocationProto fake_allocations = 8;
}
message MemzeroThunkProto {

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/convolution_thunk.h"
#include "xla/backends/gpu/runtime/copy_thunk.h"
#include "xla/backends/gpu/runtime/cudnn_thunk.h"
#include "xla/backends/gpu/runtime/dynamic_slice_thunk.h"
#include "xla/backends/gpu/runtime/fft_thunk.h"
#include "xla/backends/gpu/runtime/gemm_thunk.h"
#include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h"
@ -167,6 +168,16 @@ absl::StatusOr<std::unique_ptr<Thunk>> DeserializeThunkProto(
return Memset32BitValueThunk::FromProto(
std::move(thunk_info), thunk_proto.memset32bit_value_thunk(),
buffer_allocations);
case ThunkProto::kDynamicSliceThunk: {
auto deserializer =
[](const ThunkProto& thunk_proto,
absl::Span<const BufferAllocation> custom_allocations) {
return DeserializeThunkProto(thunk_proto, custom_allocations);
};
return DynamicSliceThunk::FromProto(std::move(thunk_info),
thunk_proto.dynamic_slice_thunk(),
buffer_allocations, deserializer);
}
default:
std::optional<absl::string_view> unsupported_thunk_type =
GetStoredThunkTypeName(thunk_proto);