mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Use Deserializer lambda for embedded thunks in DynamicSliceThunk
PiperOrigin-RevId: 826474606
This commit is contained in:
parent
26d0882419
commit
ecc2510eb0
|
|
@ -226,7 +226,6 @@ cc_library(
|
|||
":sequential_thunk",
|
||||
":thunk",
|
||||
":thunk_proto_cc",
|
||||
":thunk_proto_deserialization",
|
||||
"//xla:literal",
|
||||
"//xla:literal_util",
|
||||
"//xla:shape_util",
|
||||
|
|
@ -270,6 +269,7 @@ xla_test(
|
|||
":gemm_thunk",
|
||||
":sequential_thunk",
|
||||
":thunk",
|
||||
":thunk_proto_deserialization",
|
||||
"//xla:shape_util",
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/ffi",
|
||||
|
|
|
|||
|
|
@ -41,7 +41,6 @@ limitations under the License.
|
|||
#include "xla/backends/gpu/runtime/sequential_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/thunk.h"
|
||||
#include "xla/backends/gpu/runtime/thunk.pb.h"
|
||||
#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h"
|
||||
#include "xla/hlo/evaluator/hlo_evaluator.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
#include "xla/literal.h"
|
||||
|
|
@ -617,7 +616,8 @@ absl::StatusOr<ThunkProto> DynamicSliceThunk::ToProto() const {
|
|||
absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
|
||||
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
|
||||
absl::Span<const BufferAllocation> buffer_allocations,
|
||||
absl::Span<const BufferAllocation> fake_allocations) {
|
||||
absl::Span<const BufferAllocation> fake_allocations,
|
||||
const Deserializer& deserializer) {
|
||||
// offset_as_function_of_indvar_metadata
|
||||
std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
|
||||
offset_as_function_of_indvar_metadata;
|
||||
|
|
@ -677,9 +677,9 @@ absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
|
|||
// embedded_thunk
|
||||
std::vector<std::unique_ptr<Thunk>> embedded_thunks;
|
||||
for (const auto& thunk_proto : proto.embedded_thunk().thunks()) {
|
||||
TF_ASSIGN_OR_RETURN(auto thunk,
|
||||
DeserializeThunkProto(thunk_proto, fake_allocations));
|
||||
embedded_thunks.push_back(std::move(thunk));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> embedded_thunk,
|
||||
deserializer(thunk_proto));
|
||||
embedded_thunks.push_back(std::move(embedded_thunk));
|
||||
}
|
||||
|
||||
// leave fake_allocations empty, because we manage their lifetime outside
|
||||
|
|
|
|||
|
|
@ -186,10 +186,12 @@ class DynamicSliceThunk : public Thunk {
|
|||
// replaced during execution in `ExecuteOnStream` with the actual (dynamic)
|
||||
// slices. We have to create these outside of this method to manage their
|
||||
// lifetime correctly.
|
||||
// `deserializer`: The deserializer is used to deserialize the embedded thunk.
|
||||
static absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> FromProto(
|
||||
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
|
||||
absl::Span<const BufferAllocation> buffer_allocations,
|
||||
absl::Span<const BufferAllocation> fake_allocations);
|
||||
absl::Span<const BufferAllocation> fake_allocations,
|
||||
const Deserializer& deserializer);
|
||||
|
||||
std::optional<const OffsetAsFunctionOfIndvarModulesMetadata*>
|
||||
get_offset_function() const {
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ limitations under the License.
|
|||
#include "xla/backends/gpu/runtime/gemm_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/sequential_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/thunk.h"
|
||||
#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h"
|
||||
#include "xla/ffi/attribute_map.h"
|
||||
#include "xla/ffi/ffi.h"
|
||||
#include "xla/ffi/ffi_api.h"
|
||||
|
|
@ -108,11 +109,18 @@ void CheckProtoRoundTrip(const DynamicSliceThunk& thunk,
|
|||
BufferAllocation(i, arguments[i].value().allocation()->size(), 0));
|
||||
}
|
||||
}
|
||||
|
||||
Thunk::Deserializer deserializer =
|
||||
[&buffer_allocations](const ThunkProto& thunk_proto)
|
||||
-> absl::StatusOr<std::unique_ptr<Thunk>> {
|
||||
return DeserializeThunkProto(thunk_proto, buffer_allocations);
|
||||
};
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto thunk_from_proto,
|
||||
DynamicSliceThunk::FromProto(Thunk::ThunkInfo(), proto,
|
||||
/*buffer_allocations=*/buffer_allocations,
|
||||
/*fake_allocations=*/fake_allocations_span));
|
||||
/*fake_allocations=*/fake_allocations_span,
|
||||
deserializer));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto proto_roundtrip, thunk_from_proto->ToProto());
|
||||
auto dynamic_slice_thunk_proto_roundtrip =
|
||||
proto_roundtrip.dynamic_slice_thunk();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user