mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add de/serializaton of fake_allocations in DynamicSliceThunk.
PiperOrigin-RevId: 826541399
This commit is contained in:
parent
ecc2510eb0
commit
6ff7f9c87f
|
|
@ -2452,6 +2452,7 @@ tf_proto_library(
|
|||
"//xla:xla_data_proto",
|
||||
"//xla/core/host_offloading:host_offloading_executable_proto",
|
||||
"//xla/service:buffer_assignment_proto",
|
||||
"//xla/service:hlo_proto",
|
||||
"//xla/service/gpu:backend_configs",
|
||||
"//xla/service/gpu:gpu_conv_runner_proto",
|
||||
"//xla/service/gpu:gpu_norm_runner_proto",
|
||||
|
|
@ -2490,6 +2491,7 @@ cc_library(
|
|||
":convolution_thunk",
|
||||
":copy_thunk",
|
||||
":cudnn_thunk",
|
||||
":dynamic_slice_thunk",
|
||||
":fft_thunk",
|
||||
":gemm_thunk",
|
||||
":gpublas_lt_matmul_thunk",
|
||||
|
|
|
|||
|
|
@ -610,14 +610,19 @@ absl::StatusOr<ThunkProto> DynamicSliceThunk::ToProto() const {
|
|||
->mutable_offset_as_function_of_indvar_modules_metadata(),
|
||||
offset_as_function_of_indvar_metadata_->ToProto());
|
||||
}
|
||||
|
||||
// fake_allocations
|
||||
for (const auto& fake_allocation : fake_allocations_) {
|
||||
*dynamic_slice_proto->add_fake_allocations() = fake_allocation.ToProto();
|
||||
}
|
||||
|
||||
return proto;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
|
||||
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
|
||||
absl::Span<const BufferAllocation> buffer_allocations,
|
||||
absl::Span<const BufferAllocation> fake_allocations,
|
||||
const Deserializer& deserializer) {
|
||||
const DeserializerWithCustomAllocations& deserializer) {
|
||||
// offset_as_function_of_indvar_metadata
|
||||
std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
|
||||
offset_as_function_of_indvar_metadata;
|
||||
|
|
@ -674,20 +679,24 @@ absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
|
|||
}
|
||||
}
|
||||
|
||||
// fake_allocations
|
||||
std::vector<BufferAllocation> fake_allocations;
|
||||
for (const auto& fake_allocation_proto : proto.fake_allocations()) {
|
||||
fake_allocations.push_back(
|
||||
BufferAllocation::FromProto(fake_allocation_proto));
|
||||
}
|
||||
|
||||
// embedded_thunk
|
||||
std::vector<std::unique_ptr<Thunk>> embedded_thunks;
|
||||
for (const auto& thunk_proto : proto.embedded_thunk().thunks()) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> embedded_thunk,
|
||||
deserializer(thunk_proto));
|
||||
deserializer(thunk_proto, fake_allocations));
|
||||
embedded_thunks.push_back(std::move(embedded_thunk));
|
||||
}
|
||||
|
||||
// leave fake_allocations empty, because we manage their lifetime outside
|
||||
// of this function.
|
||||
return std::make_unique<DynamicSliceThunk>(
|
||||
thunk_info, std::make_unique<ThunkSequence>(std::move(embedded_thunks)),
|
||||
std::move(arguments),
|
||||
/*fake_allocations=*/std::vector<BufferAllocation>(), std::move(offsets),
|
||||
std::move(arguments), std::move(fake_allocations), std::move(offsets),
|
||||
std::move(orig_shapes), std::move(sliced_shapes),
|
||||
std::move(offset_byte_sizes),
|
||||
std::move(offset_as_function_of_indvar_metadata));
|
||||
|
|
|
|||
|
|
@ -181,17 +181,11 @@ class DynamicSliceThunk : public Thunk {
|
|||
// `buffer_allocations`: the actual buffer allocations; required to parse the
|
||||
// `arguments` (BufferAllocation::Slice) -- the tensors that we are later
|
||||
// slicing from.
|
||||
// `fake_allocations`: The fake allocations that are used as
|
||||
// placeholders during creation of the embedded thunk. These are being
|
||||
// replaced during execution in `ExecuteOnStream` with the actual (dynamic)
|
||||
// slices. We have to create these outside of this method to manage their
|
||||
// lifetime correctly.
|
||||
// `deserializer`: The deserializer is used to deserialize the embedded thunk.
|
||||
static absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> FromProto(
|
||||
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
|
||||
absl::Span<const BufferAllocation> buffer_allocations,
|
||||
absl::Span<const BufferAllocation> fake_allocations,
|
||||
const Deserializer& deserializer);
|
||||
const DeserializerWithCustomAllocations& deserializer);
|
||||
|
||||
std::optional<const OffsetAsFunctionOfIndvarModulesMetadata*>
|
||||
get_offset_function() const {
|
||||
|
|
|
|||
|
|
@ -110,16 +110,17 @@ void CheckProtoRoundTrip(const DynamicSliceThunk& thunk,
|
|||
}
|
||||
}
|
||||
|
||||
Thunk::Deserializer deserializer =
|
||||
[&buffer_allocations](const ThunkProto& thunk_proto)
|
||||
Thunk::DeserializerWithCustomAllocations deserializer =
|
||||
[](const ThunkProto& thunk_proto,
|
||||
absl::Span<const BufferAllocation> fake_allocations_span)
|
||||
-> absl::StatusOr<std::unique_ptr<Thunk>> {
|
||||
return DeserializeThunkProto(thunk_proto, buffer_allocations);
|
||||
return DeserializeThunkProto(thunk_proto, fake_allocations_span);
|
||||
};
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto thunk_from_proto,
|
||||
DynamicSliceThunk::FromProto(Thunk::ThunkInfo(), proto,
|
||||
/*buffer_allocations=*/buffer_allocations,
|
||||
/*fake_allocations=*/fake_allocations_span,
|
||||
deserializer));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto proto_roundtrip, thunk_from_proto->ToProto());
|
||||
auto dynamic_slice_thunk_proto_roundtrip =
|
||||
|
|
|
|||
|
|
@ -571,6 +571,10 @@ class Thunk {
|
|||
absl::AnyInvocable<absl::StatusOr<std::unique_ptr<Thunk>>(
|
||||
const ThunkProto&) const>;
|
||||
|
||||
using DeserializerWithCustomAllocations =
|
||||
absl::AnyInvocable<absl::StatusOr<std::unique_ptr<Thunk>>(
|
||||
const ThunkProto&, absl::Span<const BufferAllocation>) const>;
|
||||
|
||||
void add_control_predecessor(const Thunk* control_predecessor) {
|
||||
control_predecessors_.push_back(control_predecessor);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import "xla/service/buffer_assignment.proto";
|
|||
import "xla/service/gpu/gpu_conv_runner.proto";
|
||||
import "xla/service/gpu/gpu_norm_runner.proto";
|
||||
import "xla/service/gpu/launch_dimensions.proto";
|
||||
import "xla/service/hlo.proto";
|
||||
import "xla/stream_executor/gpu/gpu_blas_lt.proto";
|
||||
import "xla/stream_executor/gpu/tma_metadata.proto";
|
||||
import "xla/stream_executor/launch_dim.proto";
|
||||
|
|
@ -167,6 +168,7 @@ message DynamicSliceThunkProto {
|
|||
repeated OptionalInt64Proto offset_byte_sizes = 6;
|
||||
optional OffsetAsFunctionOfIndvarModulesMetadataProto
|
||||
offset_as_function_of_indvar_modules_metadata = 7;
|
||||
repeated BufferAllocationProto fake_allocations = 8;
|
||||
}
|
||||
|
||||
message MemzeroThunkProto {
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||
#include "xla/backends/gpu/runtime/convolution_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/copy_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/cudnn_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/dynamic_slice_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/fft_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/gemm_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h"
|
||||
|
|
@ -167,6 +168,16 @@ absl::StatusOr<std::unique_ptr<Thunk>> DeserializeThunkProto(
|
|||
return Memset32BitValueThunk::FromProto(
|
||||
std::move(thunk_info), thunk_proto.memset32bit_value_thunk(),
|
||||
buffer_allocations);
|
||||
case ThunkProto::kDynamicSliceThunk: {
|
||||
auto deserializer =
|
||||
[](const ThunkProto& thunk_proto,
|
||||
absl::Span<const BufferAllocation> custom_allocations) {
|
||||
return DeserializeThunkProto(thunk_proto, custom_allocations);
|
||||
};
|
||||
return DynamicSliceThunk::FromProto(std::move(thunk_info),
|
||||
thunk_proto.dynamic_slice_thunk(),
|
||||
buffer_allocations, deserializer);
|
||||
}
|
||||
default:
|
||||
std::optional<absl::string_view> unsupported_thunk_type =
|
||||
GetStoredThunkTypeName(thunk_proto);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user