Use std::vector<BufferAllocation> instead of std::vector<std::unique_ptr<BufferAllocation>> in DynamicSliceThunk.

`BufferAllocation::Slice` stores a raw pointer to the corresponding `BufferAllocation`. Now we keep the embedded thunk allocations alive by stroing unique_ptrs in the wrapping DynamicSliceThunk. The current design makes it hard to reuse the existing infrastructure, specifically to serialize `DynamicSliceThunk`. To address this, I'm changing fake_allocations to be  `std::vector<BufferAllocation>`.

The move constructor `std::vector::vector(std::vector&&)` is guaranteed to have constant time complexity and therefore it steals the internal data buffer from the source vector. This infers that the pointers to allocations are kept stable as long as:
* we preallocate the vector size
* we never copy the vector, but move

To make it safer for later usage, we can explicitely prohibid BufferAllocation to be  copyable/moveable. I'm going to do this in the following cl.

PiperOrigin-RevId: 826440060
This commit is contained in:
Aliia Khasanova 2025-10-31 04:58:35 -07:00 committed by TensorFlower Gardener
parent 3326b0221f
commit add489fd8d
8 changed files with 142 additions and 176 deletions

View File

@ -613,7 +613,7 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
// different offset by creating new fake allocations so each operand will
// have a different buffer index. The slices can thus always start at offset
// 0. DynamicSliceThunk will take care of the offset adjustment.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(4);
std::vector<BufferAllocation> fake_allocations(4, {0, 0, 0});
if (fusion.shape().IsArray()) {
TF_ASSIGN_OR_RETURN(
output, GetResultSlice(buffer_assignment, adaptor, fusion, custom_call,
@ -645,10 +645,10 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
offset_buffer_indices, orig_shapes, sliced_shapes, offset_byte_sizes,
extracted_offset_modules, arg_idx, can_compute_indvar_on_host, while_op,
indvar_idx, inlined_module));
fake_allocations[arg_idx] = std::make_unique<BufferAllocation>(
fake_allocations[arg_idx] = BufferAllocation(
/*index=*/arg_idx, workspace->size(), /*color=*/0);
slice_workspace_fake = BufferAllocation::Slice(
fake_allocations[arg_idx].get(), 0, workspace->size());
slice_workspace_fake = BufferAllocation::Slice(&fake_allocations[arg_idx],
0, workspace->size());
}
if (absl::c_all_of(slice_instrs, [&](auto slice_instr) {
@ -676,27 +676,27 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
unsigned fake_arg_idx = 0;
int64_t lhs_byte_size =
ShapeUtil::ByteSizeOf(custom_call.operand(fake_arg_idx)->shape());
fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
fake_allocations[fake_arg_idx] = BufferAllocation(
/*index=*/fake_arg_idx, lhs_byte_size, /*color=*/0);
BufferAllocation::Slice slice_lhs_fake(fake_allocations[fake_arg_idx].get(),
0, lhs_byte_size);
BufferAllocation::Slice slice_lhs_fake(&fake_allocations[fake_arg_idx], 0,
lhs_byte_size);
fake_arg_idx++;
int64_t rhs_byte_size =
ShapeUtil::ByteSizeOf(custom_call.operand(fake_arg_idx)->shape());
fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
fake_allocations[fake_arg_idx] = BufferAllocation(
/*index=*/fake_arg_idx, rhs_byte_size, /*color=*/0);
BufferAllocation::Slice slice_rhs_fake(fake_allocations[fake_arg_idx].get(),
0, rhs_byte_size);
BufferAllocation::Slice slice_rhs_fake(&fake_allocations[fake_arg_idx], 0,
rhs_byte_size);
fake_arg_idx++;
int64_t out_fake_byte_size = ShapeUtil::ByteSizeOf(
custom_call.shape().IsArray() ? custom_call.shape()
: custom_call.shape().tuple_shapes(0));
fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
fake_allocations[fake_arg_idx] = BufferAllocation(
/*index=*/fake_arg_idx, out_fake_byte_size, /*color=*/0);
BufferAllocation::Slice slice_out_fake(fake_allocations[fake_arg_idx].get(),
0, out_fake_byte_size);
BufferAllocation::Slice slice_out_fake(&fake_allocations[fake_arg_idx], 0,
out_fake_byte_size);
ThunkSequence seq;
seq.emplace_back(std::make_unique<GemmThunk>(
thunk_info, std::move(config), slice_lhs_fake, slice_rhs_fake,
@ -962,7 +962,7 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
ir_emitter_context.platform_name());
};
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(num_args);
std::vector<BufferAllocation> fake_allocations(num_args, {0, 0, 0});
if (absl::c_any_of(slice_instrs, IsDynamicSliceOrDynamicUpdateSlice)) {
// Creating embedded custom call thunk.
unsigned fake_arg_idx = 0;
@ -982,10 +982,10 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
}
int64_t operand_byte_size = ShapeUtil::ByteSizeOf(subshape);
fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
fake_allocations[fake_arg_idx] = BufferAllocation(
/*index=*/fake_arg_idx, operand_byte_size, /*color=*/0);
BufferAllocation::Slice fake_slice(
fake_allocations[fake_arg_idx].get(), 0, operand_byte_size);
BufferAllocation::Slice fake_slice(&fake_allocations[fake_arg_idx],
0, operand_byte_size);
fake_arg_idx++;
fake_operands.push_back(ShapedSlice{fake_slice, subshape});
@ -1007,10 +1007,10 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
}
int64_t result_byte_size = ShapeUtil::ByteSizeOf(subshape);
fake_allocations[fake_arg_idx] = std::make_unique<BufferAllocation>(
fake_allocations[fake_arg_idx] = BufferAllocation(
/*index=*/fake_arg_idx, result_byte_size, /*color=*/0);
BufferAllocation::Slice fake_slice(
fake_allocations[fake_arg_idx].get(), 0, result_byte_size);
BufferAllocation::Slice fake_slice(&fake_allocations[fake_arg_idx], 0,
result_byte_size);
fake_arg_idx++;
fake_results.push_back(ShapedSlice{fake_slice, subshape});
@ -1065,7 +1065,7 @@ using Slices = std::vector<Slice>;
// fake_arguments: the fake slices of the inputs/outputs of the hero
// instruction, when the slicing is dynamic.
struct SliceDataForCollectives {
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations;
std::vector<BufferAllocation> fake_allocations;
std::vector<HloInstruction*> slice_instrs;
Slices arguments, fake_arguments;
std::vector<std::optional<std::vector<DynamicSliceThunk::Offset>>>
@ -1076,7 +1076,7 @@ struct SliceDataForCollectives {
std::unique_ptr<HloModule> init_module, update_module;
bool isDynamic, can_compute_indvar_on_host;
explicit SliceDataForCollectives(int num_args)
: fake_allocations(num_args),
: fake_allocations(num_args, {0, 0, 0}),
slice_instrs(num_args),
arguments(num_args, std::nullopt),
fake_arguments(num_args, std::nullopt),
@ -1197,11 +1197,10 @@ CollectSliceArgumentMetadataForCollectives(
unsigned fake_arg_idx = 0;
for (HloInstruction* operand : instr->operands()) {
int64_t operand_byte_size = ShapeUtil::ByteSizeOf(operand->shape());
slice_data.fake_allocations[fake_arg_idx] =
std::make_unique<BufferAllocation>(
slice_data.fake_allocations[fake_arg_idx] = BufferAllocation(
/*index=*/fake_arg_idx, operand_byte_size, /*color=*/0);
BufferAllocation::Slice fake_slice(
/*allocation=*/slice_data.fake_allocations[fake_arg_idx].get(),
/*allocation=*/&slice_data.fake_allocations[fake_arg_idx],
/*offset=*/0,
/*size=*/operand_byte_size);
slice_data.fake_arguments[fake_arg_idx] = fake_slice;
@ -1217,12 +1216,11 @@ CollectSliceArgumentMetadataForCollectives(
}
for (const HloInstruction* user : collective_results) {
int64_t out_fake_byte_size = ShapeUtil::ByteSizeOf(user->shape());
slice_data.fake_allocations[fake_arg_idx] =
std::make_unique<BufferAllocation>(
slice_data.fake_allocations[fake_arg_idx] = BufferAllocation(
/*index=*/fake_arg_idx, /*size=*/out_fake_byte_size,
/*color=*/0);
BufferAllocation::Slice fake_slice(
/*allocation=*/slice_data.fake_allocations[fake_arg_idx].get(),
/*allocation=*/&slice_data.fake_allocations[fake_arg_idx],
/*offset=*/0, /*size=*/out_fake_byte_size);
slice_data.fake_arguments[fake_arg_idx] = fake_slice;
fake_arg_idx++;

View File

@ -2377,7 +2377,7 @@ CommandBufferCmd::BufferUseVector CollectiveBroadcastCmd::buffers() const {
DynamicSliceFusionCmd::DynamicSliceFusionCmd(
CommandBufferCmdExecutor embedded_commands,
std::vector<std::optional<BufferAllocation::Slice>> arguments,
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations,
std::vector<BufferAllocation> fake_allocations,
std::vector<std::optional<std::vector<DynamicSliceThunk::Offset>>> offsets,
std::vector<std::optional<Shape>> orig_shapes,
std::vector<std::optional<Shape>> sliced_shapes,

View File

@ -1217,7 +1217,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd {
DynamicSliceFusionCmd(
CommandBufferCmdExecutor embedded_commands,
std::vector<std::optional<BufferAllocation::Slice>> arguments,
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations_,
std::vector<BufferAllocation> fake_allocations,
std::vector<std::optional<std::vector<DynamicSliceThunk::Offset>>>
offsets,
std::vector<std::optional<Shape>> orig_shapes,
@ -1252,7 +1252,7 @@ class DynamicSliceFusionCmd : public CommandBufferCmd {
private:
CommandBufferCmdExecutor embedded_commands_;
std::vector<DynamicSliceThunk::SliceDef> slices_;
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations_;
std::vector<BufferAllocation> fake_allocations_;
// Pinned host memory for transferring offset values from device to host.
absl::Mutex mutex_;

View File

@ -197,10 +197,10 @@ static absl::StatusOr<Command> Convert(
ConvertToCommands(thunk.get_embedded_thunk()->thunks(), options));
auto& thunk_fake_allocations = thunk.get_fake_allocations();
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations;
std::vector<BufferAllocation> fake_allocations;
for (auto it = thunk_fake_allocations.begin();
it != thunk_fake_allocations.end(); ++it) {
fake_allocations.push_back(std::make_unique<BufferAllocation>(**it));
fake_allocations.push_back(BufferAllocation(*it));
}
return std::make_unique<DynamicSliceFusionCmd>(
std::move(embedded_cmds), thunk.get_arguments(),

View File

@ -949,24 +949,22 @@ TEST(CommandBufferThunkTest, DISABLED_DynamicSliceFusionCmd) {
TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024));
// Prepare buffer allocations for recording command buffer.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(4);
fake_allocations[0] = std::make_unique<BufferAllocation>(
std::vector<BufferAllocation> fake_allocations;
fake_allocations.reserve(4);
fake_allocations.emplace_back(
/*index=*/0, fake_lhs_length, /*color=*/0);
fake_allocations[1] =
std::make_unique<BufferAllocation>(/*index=*/1, rhs_length, /*color=*/0);
fake_allocations[2] =
std::make_unique<BufferAllocation>(/*index=*/2, out_length,
fake_allocations.emplace_back(
/*index=*/1, rhs_length, /*color=*/0);
fake_allocations.emplace_back(/*index=*/2, out_length,
/*color=*/0);
fake_allocations[3] =
std::make_unique<BufferAllocation>(/*index=*/3, 1024 * 1024,
fake_allocations.emplace_back(/*index=*/3, 1024 * 1024,
/*color=*/0);
BufferAllocation::Slice fake_slice_lhs(fake_allocations[0].get(), 0,
BufferAllocation::Slice fake_slice_lhs(&fake_allocations[0], 0,
fake_lhs_length);
BufferAllocation::Slice slice_rhs(fake_allocations[1].get(), 0, rhs_length);
BufferAllocation::Slice slice_out(fake_allocations[2].get(), 0, out_length);
BufferAllocation::Slice slice_workspace(fake_allocations[3].get(), 0,
1024 * 1024);
BufferAllocation::Slice slice_rhs(&fake_allocations[1], 0, rhs_length);
BufferAllocation::Slice slice_out(&fake_allocations[2], 0, out_length);
BufferAllocation::Slice slice_workspace(&fake_allocations[3], 0, 1024 * 1024);
auto config = GemmConfig::For(
ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {}, {1},
ShapeUtil::MakeShape(PrimitiveType::F32, {4, 3}), {}, {0},

View File

@ -175,7 +175,7 @@ std::string DynamicSliceThunk::SliceDef::ToString() const {
DynamicSliceThunk::DynamicSliceThunk(
ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
std::vector<std::optional<BufferAllocation::Slice>> arguments,
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations,
std::vector<BufferAllocation> fake_allocations,
std::vector<std::optional<std::vector<Offset>>> offsets,
std::vector<std::optional<Shape>> orig_shapes,
std::vector<std::optional<Shape>> sliced_shapes,
@ -687,8 +687,8 @@ absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
return std::make_unique<DynamicSliceThunk>(
thunk_info, std::make_unique<ThunkSequence>(std::move(embedded_thunks)),
std::move(arguments),
/*fake_allocations=*/std::vector<std::unique_ptr<BufferAllocation>>(),
std::move(offsets), std::move(orig_shapes), std::move(sliced_shapes),
/*fake_allocations=*/std::vector<BufferAllocation>(), std::move(offsets),
std::move(orig_shapes), std::move(sliced_shapes),
std::move(offset_byte_sizes),
std::move(offset_as_function_of_indvar_metadata));
}

View File

@ -114,7 +114,7 @@ class DynamicSliceThunk : public Thunk {
DynamicSliceThunk(
ThunkInfo thunk_info, std::unique_ptr<ThunkSequence> embedded_thunk,
std::vector<std::optional<BufferAllocation::Slice>> arguments,
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations,
std::vector<BufferAllocation> fake_allocations,
std::vector<std::optional<std::vector<Offset>>> offsets,
std::vector<std::optional<Shape>> orig_shapes,
std::vector<std::optional<Shape>> sliced_shapes,
@ -150,8 +150,7 @@ class DynamicSliceThunk : public Thunk {
return arguments_;
}
const std::vector<std::unique_ptr<BufferAllocation>>& get_fake_allocations()
const {
const std::vector<BufferAllocation>& get_fake_allocations() const {
return fake_allocations_;
}
@ -204,7 +203,7 @@ class DynamicSliceThunk : public Thunk {
private:
std::unique_ptr<SequentialThunk> embedded_thunk_;
std::vector<std::optional<BufferAllocation::Slice>> arguments_;
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations_;
std::vector<BufferAllocation> fake_allocations_;
std::vector<std::optional<std::vector<Offset>>> offsets_;
std::vector<std::optional<Shape>> orig_shapes_;
std::vector<std::optional<Shape>> sliced_shapes_;

View File

@ -161,28 +161,24 @@ absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> CreateSlicedGemmThunk(
int64_t out_length = sizeof(float) * 1 * 1;
int64_t offset_length = sizeof(int64_t);
// Preparing buffer allocation slices for thunk creations.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations;
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/0, rhs_length, /*color=*/0));
BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0,
std::vector<BufferAllocation> fake_allocations;
fake_allocations.reserve(4);
fake_allocations.emplace_back(/*index=*/0, rhs_length, /*color=*/0);
BufferAllocation::Slice slice_lhs_fake(&fake_allocations.back(), 0,
rhs_length);
auto alloc_lhs =
std::make_unique<BufferAllocation>(/*index=*/0, lhs_length, /*color=*/0);
BufferAllocation::Slice slice_lhs(alloc_lhs.get(), 0, lhs_length);
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/1, rhs_length, /*color=*/0));
BufferAllocation::Slice slice_rhs(fake_allocations.back().get(), 0,
rhs_length);
fake_allocations.emplace_back(/*index=*/1, rhs_length, /*color=*/0);
BufferAllocation::Slice slice_rhs(&fake_allocations.back(), 0, rhs_length);
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/2, out_length, /*color=*/0));
BufferAllocation::Slice slice_out(fake_allocations.back().get(), 0,
out_length);
fake_allocations.emplace_back(/*index=*/2, out_length, /*color=*/0);
BufferAllocation::Slice slice_out(&fake_allocations.back(), 0, out_length);
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/3, 1024 * 1024, /*color=*/0));
BufferAllocation::Slice slice_workspace(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/3, 1024 * 1024, /*color=*/0);
BufferAllocation::Slice slice_workspace(&fake_allocations.back(), 0,
1024 * 1024);
auto alloc_lhs_offset_0 = std::make_unique<BufferAllocation>(
@ -323,14 +319,13 @@ CreateMultipleSlicedOperandsGemmThunk(
int64_t offset_length = sizeof(int64_t);
int64_t slice_length = sizeof(float) * 3;
// Preparing buffer allocation slices for thunk creations.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations;
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/0, slice_length, /*color=*/0));
BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0,
std::vector<BufferAllocation> fake_allocations;
fake_allocations.reserve(4);
fake_allocations.emplace_back(/*index=*/0, slice_length, /*color=*/0);
BufferAllocation::Slice slice_lhs_fake(&fake_allocations.back(), 0,
slice_length);
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/1, slice_length, /*color=*/0));
BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/1, slice_length, /*color=*/0);
BufferAllocation::Slice slice_rhs_fake(&fake_allocations.back(), 0,
slice_length);
auto alloc_lhs =
std::make_unique<BufferAllocation>(/*index=*/0, length, /*color=*/0);
@ -338,13 +333,10 @@ CreateMultipleSlicedOperandsGemmThunk(
auto alloc_rhs =
std::make_unique<BufferAllocation>(/*index=*/1, length, /*color=*/0);
BufferAllocation::Slice slice_rhs(alloc_rhs.get(), 0, length);
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/2, out_length, /*color=*/0));
BufferAllocation::Slice slice_out(fake_allocations.back().get(), 0,
out_length);
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/3, 1024 * 1024, /*color=*/0));
BufferAllocation::Slice slice_workspace(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/2, out_length, /*color=*/0);
BufferAllocation::Slice slice_out(&fake_allocations.back(), 0, out_length);
fake_allocations.emplace_back(/*index=*/3, 1024 * 1024, /*color=*/0);
BufferAllocation::Slice slice_workspace(&fake_allocations.back(), 0,
1024 * 1024);
auto alloc_lhs_offset_0 = std::make_unique<BufferAllocation>(
/*index=*/4, offset_length, /*color=*/0);
@ -540,21 +532,19 @@ TEST_F(DynamicSliceThunkTest, SlicedMemcpy) {
// Prepare embedded and dynamic slice thunks.
// Preparing buffer allocation slices for thunk creations.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(2);
std::vector<BufferAllocation> fake_allocations;
fake_allocations.reserve(2);
// Fake slices for embedded thunk creation.
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/0, slice_length, /*color=*/0));
BufferAllocation::Slice slice_src_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/0, slice_length, /*color=*/0);
BufferAllocation::Slice slice_src_fake(&fake_allocations.back(), 0,
slice_length);
BufferAllocation alloc_src(/*index=*/0, src_length, /*color=*/0);
BufferAllocation::Slice slice_src(&alloc_src, 0, src_length);
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/1, dst_length, /*color=*/0));
BufferAllocation::Slice slice_dst(fake_allocations.back().get(), 0,
dst_length);
fake_allocations.emplace_back(/*index=*/1, dst_length, /*color=*/0);
BufferAllocation::Slice slice_dst(&fake_allocations.back(), 0, dst_length);
BufferAllocation alloc_offset_0(/*index=*/2, offset_length, /*color=*/0);
BufferAllocation::Slice slice_offset_0(&alloc_offset_0, 0, offset_length);
@ -678,17 +668,16 @@ TEST_F(DynamicSliceThunkTest, SlicedOutputMemcpy) {
// Prepare embedded and dynamic slice thunks.
// Preparing buffer allocation slices for thunk creations.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(2);
std::vector<BufferAllocation> fake_allocations;
fake_allocations.reserve(2);
// Fake slices for embedded thunk creation.
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/0, slice_length, /*color=*/0));
BufferAllocation::Slice slice_src_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/0, slice_length, /*color=*/0);
BufferAllocation::Slice slice_src_fake(&fake_allocations.back(), 0,
slice_length);
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/1, slice_length, /*color=*/0));
BufferAllocation::Slice slice_dst_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/1, slice_length, /*color=*/0);
BufferAllocation::Slice slice_dst_fake(&fake_allocations.back(), 0,
slice_length);
BufferAllocation alloc_src(/*index=*/0, src_length, /*color=*/0);
@ -868,22 +857,19 @@ CreateSlicedGemmArbitraryArgumentOrderThunk(
int64_t offset_length = sizeof(int64_t);
// Preparing buffer allocation slices for thunk creations.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations;
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/0, rhs_length, /*color=*/0));
BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0,
std::vector<BufferAllocation> fake_allocations;
fake_allocations.reserve(4);
fake_allocations.emplace_back(/*index=*/0, rhs_length, /*color=*/0);
BufferAllocation::Slice slice_lhs_fake(&fake_allocations.back(), 0,
rhs_length);
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/1, rhs_length, /*color=*/0));
BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/1, rhs_length, /*color=*/0);
BufferAllocation::Slice slice_rhs_fake(&fake_allocations.back(), 0,
rhs_length);
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/2, out_length, /*color=*/0));
BufferAllocation::Slice slice_out_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/2, out_length, /*color=*/0);
BufferAllocation::Slice slice_out_fake(&fake_allocations.back(), 0,
out_length);
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/3, 1024 * 1024, /*color=*/0));
BufferAllocation::Slice slice_workspace_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/3, 1024 * 1024, /*color=*/0);
BufferAllocation::Slice slice_workspace_fake(&fake_allocations.back(), 0,
1024 * 1024);
auto alloc_lhs =
@ -1043,22 +1029,19 @@ CreateSlicedGemmArbitraryNumberOfArgumentsThunk(
int64_t offset_length = sizeof(int64_t);
// Preparing buffer allocation slices for thunk creations.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations;
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/0, rhs_length, /*color=*/0));
BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0,
std::vector<BufferAllocation> fake_allocations;
fake_allocations.reserve(4);
fake_allocations.emplace_back(/*index=*/0, rhs_length, /*color=*/0);
BufferAllocation::Slice slice_lhs_fake(&fake_allocations.back(), 0,
rhs_length);
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/1, rhs_length, /*color=*/0));
BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/1, rhs_length, /*color=*/0);
BufferAllocation::Slice slice_rhs_fake(&fake_allocations.back(), 0,
rhs_length);
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/2, out_length, /*color=*/0));
BufferAllocation::Slice slice_out_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/2, out_length, /*color=*/0);
BufferAllocation::Slice slice_out_fake(&fake_allocations.back(), 0,
out_length);
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/3, 1024 * 1024, /*color=*/0));
BufferAllocation::Slice slice_workspace_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/3, 1024 * 1024, /*color=*/0);
BufferAllocation::Slice slice_workspace_fake(&fake_allocations.back(), 0,
1024 * 1024);
auto alloc_lhs =
@ -1220,29 +1203,24 @@ CreateSlicedTupledOperandGemmThunk(
int64_t offset_length = sizeof(int64_t);
// Preparing buffer allocation slices for thunk creations.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations;
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/0, rhs_length, /*color=*/0));
BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0,
std::vector<BufferAllocation> fake_allocations;
fake_allocations.reserve(4);
fake_allocations.emplace_back(/*index=*/0, rhs_length, /*color=*/0);
BufferAllocation::Slice slice_lhs_fake(&fake_allocations.back(), 0,
rhs_length);
auto alloc_lhs = std::make_unique<BufferAllocation>(
/*index=*/0, 3 * lhs_length, /*color=*/0);
BufferAllocation::Slice slice_lhs(alloc_lhs.get(), lhs_length, lhs_length);
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/1, rhs_length, /*color=*/0));
BufferAllocation::Slice slice_rhs(fake_allocations.back().get(), 0,
rhs_length);
fake_allocations.emplace_back(/*index=*/1, rhs_length, /*color=*/0);
BufferAllocation::Slice slice_rhs(&fake_allocations.back(), 0, rhs_length);
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/2, out_length, /*color=*/0));
BufferAllocation::Slice slice_out(fake_allocations.back().get(), 0,
out_length);
fake_allocations.emplace_back(/*index=*/2, out_length, /*color=*/0);
BufferAllocation::Slice slice_out(&fake_allocations.back(), 0, out_length);
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/3, 1024 * 1024, /*color=*/0));
BufferAllocation::Slice slice_workspace(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/3, 1024 * 1024, /*color=*/0);
BufferAllocation::Slice slice_workspace(&fake_allocations.back(), 0,
1024 * 1024);
auto alloc_lhs_offset_0 = std::make_unique<BufferAllocation>(
@ -1398,17 +1376,16 @@ TEST_F(DynamicSliceThunkTest, SlicedMemcpyOOB) {
// Prepare embedded and dynamic slice thunks.
// Preparing buffer allocation slices for thunk creations.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations(2);
std::vector<BufferAllocation> fake_allocations;
fake_allocations.reserve(2);
// Fake slices for embedded thunk creation.
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/0, slice_length, /*color=*/0));
BufferAllocation::Slice slice_src_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/0, slice_length, /*color=*/0);
BufferAllocation::Slice slice_src_fake(&fake_allocations.back(), 0,
slice_length);
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/1, slice_length, /*color=*/0));
BufferAllocation::Slice slice_dst_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/1, slice_length, /*color=*/0);
BufferAllocation::Slice slice_dst_fake(&fake_allocations.back(), 0,
slice_length);
BufferAllocation alloc_src(/*index=*/0, src_length, /*color=*/0);
@ -1591,25 +1568,22 @@ CreateSlicedOperandsSameBufferGemmThunk(
int64_t offset_length = sizeof(int64_t);
// Preparing buffer allocation slices for thunk creations.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations;
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/0, rhs_length, /*color=*/0));
BufferAllocation::Slice slice_lhs_fake(fake_allocations.back().get(), 0,
std::vector<BufferAllocation> fake_allocations;
fake_allocations.reserve(4);
fake_allocations.emplace_back(/*index=*/0, rhs_length, /*color=*/0);
BufferAllocation::Slice slice_lhs_fake(&fake_allocations.back(), 0,
rhs_length);
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/1, rhs_length, /*color=*/0));
BufferAllocation::Slice slice_rhs_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/1, rhs_length, /*color=*/0);
BufferAllocation::Slice slice_rhs_fake(&fake_allocations.back(), 0,
rhs_length);
fake_allocations.push_back(
std::make_unique<BufferAllocation>(/*index=*/2, out_length, /*color=*/0));
BufferAllocation::Slice slice_out_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/2, out_length, /*color=*/0);
BufferAllocation::Slice slice_out_fake(&fake_allocations.back(), 0,
out_length);
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/3, 1024 * 1024, /*color=*/0));
BufferAllocation::Slice slice_workspace_fake(fake_allocations.back().get(), 0,
fake_allocations.emplace_back(/*index=*/3, 1024 * 1024, /*color=*/0);
BufferAllocation::Slice slice_workspace_fake(&fake_allocations.back(), 0,
1024 * 1024);
auto alloc = std::make_unique<BufferAllocation>(
@ -1804,11 +1778,11 @@ CreateHostInductionVariableAndOffsetEvaluationThunk(
int64_t out_length = sizeof(float) * 1 * 1;
// Preparing buffer allocation slices for thunk creations.
std::vector<std::unique_ptr<BufferAllocation>> fake_allocations;
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/0, /*size=*/rhs_length, /*color=*/0));
std::vector<BufferAllocation> fake_allocations;
fake_allocations.reserve(4);
fake_allocations.emplace_back(/*index=*/0, /*size=*/rhs_length, /*color=*/0);
BufferAllocation::Slice slice_lhs_fake(
/*allocation=*/fake_allocations.back().get(), /*offset=*/0,
/*allocation=*/&fake_allocations.back(), /*offset=*/0,
/*size=*/rhs_length);
auto alloc_lhs = std::make_unique<BufferAllocation>(
@ -1816,22 +1790,19 @@ CreateHostInductionVariableAndOffsetEvaluationThunk(
BufferAllocation::Slice slice_lhs(alloc_lhs.get(), /*offset=*/0,
/*size=*/lhs_length);
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/1, /*size=*/rhs_length, /*color=*/0));
fake_allocations.emplace_back(/*index=*/1, /*size=*/rhs_length, /*color=*/0);
BufferAllocation::Slice slice_rhs(
/*allocation=*/fake_allocations.back().get(), /*offset=*/0,
/*allocation=*/&fake_allocations.back(), /*offset=*/0,
/*size=*/rhs_length);
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/2, /*size=*/out_length, /*color=*/0));
fake_allocations.emplace_back(/*index=*/2, /*size=*/out_length, /*color=*/0);
BufferAllocation::Slice slice_out(
/*allocation=*/fake_allocations.back().get(), /*offset=*/0,
/*allocation=*/&fake_allocations.back(), /*offset=*/0,
/*size=*/out_length);
fake_allocations.push_back(std::make_unique<BufferAllocation>(
/*index=*/3, /*size=*/1024 * 1024, /*color=*/0));
fake_allocations.emplace_back(/*index=*/3, /*size=*/1024 * 1024, /*color=*/0);
BufferAllocation::Slice slice_workspace(
/*allocation=*/fake_allocations.back().get(), /*offset=*/0,
/*allocation=*/&fake_allocations.back(), /*offset=*/0,
/*size=*/1024 * 1024);
backing_allocations.push_back(std::move(alloc_lhs));