Mirror of https://github.com/zebrajr/tensorflow.git (synced 2025-12-06 00:19:58 +01:00)
Use Deserializer lambda for embedded thunks in DynamicSliceThunk
PiperOrigin-RevId: 826474606
This commit is contained in:
parent 26d0882419
commit ecc2510eb0
|
|
@ -226,7 +226,6 @@ cc_library(
|
||||||
":sequential_thunk",
|
":sequential_thunk",
|
||||||
":thunk",
|
":thunk",
|
||||||
":thunk_proto_cc",
|
":thunk_proto_cc",
|
||||||
":thunk_proto_deserialization",
|
|
||||||
"//xla:literal",
|
"//xla:literal",
|
||||||
"//xla:literal_util",
|
"//xla:literal_util",
|
||||||
"//xla:shape_util",
|
"//xla:shape_util",
|
||||||
|
|
@ -270,6 +269,7 @@ xla_test(
|
||||||
":gemm_thunk",
|
":gemm_thunk",
|
||||||
":sequential_thunk",
|
":sequential_thunk",
|
||||||
":thunk",
|
":thunk",
|
||||||
|
":thunk_proto_deserialization",
|
||||||
"//xla:shape_util",
|
"//xla:shape_util",
|
||||||
"//xla:xla_data_proto_cc",
|
"//xla:xla_data_proto_cc",
|
||||||
"//xla/ffi",
|
"//xla/ffi",
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,6 @@ limitations under the License.
|
||||||
#include "xla/backends/gpu/runtime/sequential_thunk.h"
|
#include "xla/backends/gpu/runtime/sequential_thunk.h"
|
||||||
#include "xla/backends/gpu/runtime/thunk.h"
|
#include "xla/backends/gpu/runtime/thunk.h"
|
||||||
#include "xla/backends/gpu/runtime/thunk.pb.h"
|
#include "xla/backends/gpu/runtime/thunk.pb.h"
|
||||||
#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h"
|
|
||||||
#include "xla/hlo/evaluator/hlo_evaluator.h"
|
#include "xla/hlo/evaluator/hlo_evaluator.h"
|
||||||
#include "xla/hlo/ir/hlo_module.h"
|
#include "xla/hlo/ir/hlo_module.h"
|
||||||
#include "xla/literal.h"
|
#include "xla/literal.h"
|
||||||
|
|
@ -617,7 +616,8 @@ absl::StatusOr<ThunkProto> DynamicSliceThunk::ToProto() const {
|
||||||
absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
|
absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
|
||||||
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
|
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
|
||||||
absl::Span<const BufferAllocation> buffer_allocations,
|
absl::Span<const BufferAllocation> buffer_allocations,
|
||||||
absl::Span<const BufferAllocation> fake_allocations) {
|
absl::Span<const BufferAllocation> fake_allocations,
|
||||||
|
const Deserializer& deserializer) {
|
||||||
// offset_as_function_of_indvar_metadata
|
// offset_as_function_of_indvar_metadata
|
||||||
std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
|
std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
|
||||||
offset_as_function_of_indvar_metadata;
|
offset_as_function_of_indvar_metadata;
|
||||||
|
|
@ -677,9 +677,9 @@ absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> DynamicSliceThunk::FromProto(
|
||||||
// embedded_thunk
|
// embedded_thunk
|
||||||
std::vector<std::unique_ptr<Thunk>> embedded_thunks;
|
std::vector<std::unique_ptr<Thunk>> embedded_thunks;
|
||||||
for (const auto& thunk_proto : proto.embedded_thunk().thunks()) {
|
for (const auto& thunk_proto : proto.embedded_thunk().thunks()) {
|
||||||
TF_ASSIGN_OR_RETURN(auto thunk,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> embedded_thunk,
|
||||||
DeserializeThunkProto(thunk_proto, fake_allocations));
|
deserializer(thunk_proto));
|
||||||
embedded_thunks.push_back(std::move(thunk));
|
embedded_thunks.push_back(std::move(embedded_thunk));
|
||||||
}
|
}
|
||||||
|
|
||||||
// leave fake_allocations empty, because we manage their lifetime outside
|
// leave fake_allocations empty, because we manage their lifetime outside
|
||||||
|
|
|
||||||
|
|
@ -186,10 +186,12 @@ class DynamicSliceThunk : public Thunk {
|
||||||
// replaced during execution in `ExecuteOnStream` with the actual (dynamic)
|
// replaced during execution in `ExecuteOnStream` with the actual (dynamic)
|
||||||
// slices. We have to create these outside of this method to manage their
|
// slices. We have to create these outside of this method to manage their
|
||||||
// lifetime correctly.
|
// lifetime correctly.
|
||||||
|
// `deserializer`: The deserializer is used to deserialize the embedded thunk.
|
||||||
static absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> FromProto(
|
static absl::StatusOr<std::unique_ptr<DynamicSliceThunk>> FromProto(
|
||||||
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
|
ThunkInfo thunk_info, const DynamicSliceThunkProto& proto,
|
||||||
absl::Span<const BufferAllocation> buffer_allocations,
|
absl::Span<const BufferAllocation> buffer_allocations,
|
||||||
absl::Span<const BufferAllocation> fake_allocations);
|
absl::Span<const BufferAllocation> fake_allocations,
|
||||||
|
const Deserializer& deserializer);
|
||||||
|
|
||||||
std::optional<const OffsetAsFunctionOfIndvarModulesMetadata*>
|
std::optional<const OffsetAsFunctionOfIndvarModulesMetadata*>
|
||||||
get_offset_function() const {
|
get_offset_function() const {
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||||
#include "xla/backends/gpu/runtime/gemm_thunk.h"
|
#include "xla/backends/gpu/runtime/gemm_thunk.h"
|
||||||
#include "xla/backends/gpu/runtime/sequential_thunk.h"
|
#include "xla/backends/gpu/runtime/sequential_thunk.h"
|
||||||
#include "xla/backends/gpu/runtime/thunk.h"
|
#include "xla/backends/gpu/runtime/thunk.h"
|
||||||
|
#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h"
|
||||||
#include "xla/ffi/attribute_map.h"
|
#include "xla/ffi/attribute_map.h"
|
||||||
#include "xla/ffi/ffi.h"
|
#include "xla/ffi/ffi.h"
|
||||||
#include "xla/ffi/ffi_api.h"
|
#include "xla/ffi/ffi_api.h"
|
||||||
|
|
@ -108,11 +109,18 @@ void CheckProtoRoundTrip(const DynamicSliceThunk& thunk,
|
||||||
BufferAllocation(i, arguments[i].value().allocation()->size(), 0));
|
BufferAllocation(i, arguments[i].value().allocation()->size(), 0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Thunk::Deserializer deserializer =
|
||||||
|
[&buffer_allocations](const ThunkProto& thunk_proto)
|
||||||
|
-> absl::StatusOr<std::unique_ptr<Thunk>> {
|
||||||
|
return DeserializeThunkProto(thunk_proto, buffer_allocations);
|
||||||
|
};
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
auto thunk_from_proto,
|
auto thunk_from_proto,
|
||||||
DynamicSliceThunk::FromProto(Thunk::ThunkInfo(), proto,
|
DynamicSliceThunk::FromProto(Thunk::ThunkInfo(), proto,
|
||||||
/*buffer_allocations=*/buffer_allocations,
|
/*buffer_allocations=*/buffer_allocations,
|
||||||
/*fake_allocations=*/fake_allocations_span));
|
/*fake_allocations=*/fake_allocations_span,
|
||||||
|
deserializer));
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto proto_roundtrip, thunk_from_proto->ToProto());
|
TF_ASSERT_OK_AND_ASSIGN(auto proto_roundtrip, thunk_from_proto->ToProto());
|
||||||
auto dynamic_slice_thunk_proto_roundtrip =
|
auto dynamic_slice_thunk_proto_roundtrip =
|
||||||
proto_roundtrip.dynamic_slice_thunk();
|
proto_roundtrip.dynamic_slice_thunk();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user