diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 3c0b4cc3445..6dbf2dfcd2c 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -2449,6 +2449,7 @@ tf_proto_library( ":convolution_filter_thunk_proto", ":dynamic_slice_thunk_proto", "//xla:xla_data_proto", + "//xla/core/host_offloading:host_offloading_executable_proto", "//xla/service:buffer_assignment_proto", "//xla/service/gpu:backend_configs", "//xla/service/gpu:gpu_conv_runner_proto", @@ -2744,6 +2745,7 @@ xla_test( deps = [ ":host_execute_thunk", ":thunk", + ":thunk_proto_cc", "//xla:executable_run_options", "//xla:literal", "//xla:literal_util", @@ -2770,12 +2772,14 @@ xla_test( "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:status_matchers", "//xla/tsl/platform:statusor", + "//xla/tsl/util/proto:proto_matchers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:casts", + "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk.cc index f6215c5bd50..b65308411c0 100644 --- a/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk.cc @@ -19,11 +19,13 @@ limitations under the License. #include #include #include +#include #include #include #include #include "absl/base/call_once.h" +#include "absl/base/casts.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" @@ -425,24 +427,84 @@ HostExecuteStartThunk::HostExecuteStartThunk( Thunk::ThunkInfo thunk_info, const HostOffloadingExecutableProto& host_offloading_executable_proto, absl::InlinedVector args, - absl::InlinedVector results) + absl::InlinedVector results, + std::shared_ptr async_events) : Thunk(Thunk::Kind::kHostExecuteStart, std::move(thunk_info)), args_(std::move(args)), results_(std::move(results)), - executable_proto_(host_offloading_executable_proto), - async_events_(std::make_shared()) {} + executable_proto_(host_offloading_executable_proto) { + async_events_ = + async_events ? async_events : std::make_shared(); +} std::string HostExecuteStartThunk::ToString(int indent) const { return ""; } absl::StatusOr HostExecuteStartThunk::ToProto() const { - return Unimplemented("Not implemented yet."); + ThunkProto proto; + *proto.mutable_thunk_info() = thunk_info().ToProto(); + HostExecuteStartThunkProto* host_execute_start_thunk_proto = + proto.mutable_host_execute_start_thunk(); + + *host_execute_start_thunk_proto->mutable_executable_proto() = + executable_proto_; + + for (const auto& [slice, shape] : args_) { + ShapedSliceProto* arg_proto = host_execute_start_thunk_proto->add_args(); + TF_ASSIGN_OR_RETURN(*arg_proto->mutable_slice(), slice.ToProto()); + *arg_proto->mutable_shape() = shape.ToProto(); + } + + for (const auto& [slice, shape] : results_) { + ShapedSliceProto* result_proto = + host_execute_start_thunk_proto->add_results(); + TF_ASSIGN_OR_RETURN(*result_proto->mutable_slice(), slice.ToProto()); + *result_proto->mutable_shape() = shape.ToProto(); + } + + auto async_events_unique_id = GetAsyncEventsUniqueId(); + // By design, async_events_unique_id should always be present for + // HostExecuteStartThunk. + CHECK_NE(async_events_unique_id, std::nullopt); + + host_execute_start_thunk_proto->set_async_events_unique_id( + async_events_unique_id.value().value()); + + return proto; } absl::StatusOr> HostExecuteStartThunk::FromProto( ThunkInfo thunk_info, const HostExecuteStartThunkProto& proto, - absl::Span buffer_allocations) { - return Unimplemented("Not implemented yet."); + absl::Span buffer_allocations, + HostExecuteAsyncEventsMap& async_events_map) { + absl::InlinedVector args, results; + auto shaped_slice_from_proto = + [&](const auto& shaped_slice_protos, + absl::InlinedVector& + slices_and_shapes) -> absl::Status { + for (const auto& shaped_slice_proto : shaped_slice_protos) { + TF_ASSIGN_OR_RETURN(auto slice, + BufferAllocation::Slice::FromProto( + shaped_slice_proto.slice(), buffer_allocations)); + TF_ASSIGN_OR_RETURN(auto shape, + Shape::FromProto(shaped_slice_proto.shape())); + slices_and_shapes.push_back({slice, shape}); + } + return absl::OkStatus(); + }; + + TF_RETURN_IF_ERROR(shaped_slice_from_proto(proto.args(), args)); + TF_RETURN_IF_ERROR(shaped_slice_from_proto(proto.results(), results)); + + // If async_events_map already contains an entry for the given unique id, + // that means that the pairing done thunk is already serialized and we reuse + // the id to connect them. Otherwise, create a new entry. + auto [async_event_it, _] = async_events_map.try_emplace( + AsyncEventsUniqueId(proto.async_events_unique_id()), + std::make_shared()); + return std::make_unique( + thunk_info, proto.executable_proto(), std::move(args), std::move(results), + async_event_it->second); } static HostOffloadingAllocator* GetHostOffloadingAllocator( @@ -556,6 +618,14 @@ absl::Status HostExecuteStartThunk::ExecuteOnStream( return absl::OkStatus(); } +std::optional +HostExecuteStartThunk::GetAsyncEventsUniqueId() const { + CHECK(async_events_) + << "async_events_ must not be null in HostExecuteStartThunk"; + // We rely on the fact that the pointer to async_events_ is unique. + return absl::bit_cast(async_events_.get()); +} + // HostExecuteDoneThunk HostExecuteDoneThunk::HostExecuteDoneThunk( diff --git a/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk.h b/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk.h index 4771fad13cc..a0df8525dfc 100644 --- a/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_BACKENDS_GPU_RUNTIME_HOST_EXECUTE_THUNK_H_ #include +#include #include #include @@ -70,6 +71,10 @@ class HostExecuteAsyncEvents { events_ ABSL_GUARDED_BY(events_mu_); }; +using HostExecuteAsyncEventsMap = + absl::flat_hash_map>; + class HostExecuteStartThunk : public Thunk { public: struct SliceAndShape { @@ -95,9 +100,14 @@ class HostExecuteStartThunk : public Thunk { std::string ToString(int indent) const override; absl::StatusOr ToProto() const override; + + // If async_events_map already contains an entry for the given unique id, we + // reuse the id to connect the start and done thunks. Otherwise, insert a new + // entry into the map. static absl::StatusOr> FromProto( ThunkInfo thunk_info, const HostExecuteStartThunkProto& proto, - absl::Span buffer_allocations); + absl::Span buffer_allocations, + HostExecuteAsyncEventsMap& async_events_map); absl::Status Initialize(const InitializeParams& params) override; absl::Status ExecuteOnStream(const ExecuteParams& params) override; @@ -118,12 +128,14 @@ class HostExecuteStartThunk : public Thunk { return &executable_proto_; } - protected: + std::optional GetAsyncEventsUniqueId() const override; + HostExecuteStartThunk( Thunk::ThunkInfo thunk_info, const HostOffloadingExecutableProto& host_offloading_executable_proto, absl::InlinedVector args, - absl::InlinedVector results); + absl::InlinedVector results, + std::shared_ptr async_events = nullptr); private: absl::once_flag executable_init_flag_; diff --git a/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk_test.cc index b43fade7529..4d59720875a 100644 --- a/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/host_execute_thunk_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -52,6 +53,7 @@ limitations under the License. #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/status_matchers.h" #include "xla/tsl/platform/statusor.h" +#include "xla/tsl/util/proto/proto_matchers.h" #include "xla/util.h" #include "tsl/platform/casts.h" @@ -584,6 +586,60 @@ TEST(HostExecuteDoneThunkTest, WaitingOnErrorEvent) { absl_testing::StatusIs(absl::StatusCode::kInternal)); } +TEST(HostExecuteStartThunkTest, ProtoRoundTrip) { + static constexpr char const* kHloModule = R"( + HloModule module + ENTRY add_inplace { + p0 = s32[] parameter(0) + ROOT add = s32[] add(p0, p0) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, + ParseAndReturnUnverifiedModule(kHloModule, {})); + + BufferAllocation alloc_arg(/*index=*/0, 4, /*color=*/0); + BufferAllocation alloc_result(/*index=*/1, 4, /*color=*/0); + + BufferAllocation::Slice slice_arg(&alloc_arg, 0, 4); + BufferAllocation::Slice slice_result(&alloc_result, 0, 4); + + TF_ASSERT_OK_AND_ASSIGN(auto thunk, + CreateHostExecuteStartThunk( + Thunk::ThunkInfo(), *hlo_module, + {{slice_arg, ShapeUtil::MakeShape(S32, {})}}, + {{slice_result, ShapeUtil::MakeShape(S32, {})}})); + + TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, thunk->ToProto()); + + std::vector buffer_allocations = { + BufferAllocation(/*index=*/0, /*size=*/4, /*color=*/0), + BufferAllocation(/*index=*/1, /*size=*/4, /*color=*/0)}; + + TF_ASSERT_OK_AND_ASSIGN(Thunk::ThunkInfo thunk_info, + Thunk::ThunkInfo::FromProto(proto.thunk_info())); + HostExecuteAsyncEventsMap async_events_map; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr round_trip_thunk, + HostExecuteStartThunk::FromProto(thunk_info, + proto.host_execute_start_thunk(), + buffer_allocations, async_events_map)); + + TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, + round_trip_thunk->ToProto()); + EXPECT_EQ(async_events_map.size(), 1); + EXPECT_EQ(async_events_map.begin()->first, + thunk->GetAsyncEventsUniqueId().value()); + + // ids are expected to be different, so drop them for the comparison. + round_trip_proto.mutable_host_execute_start_thunk() + ->clear_async_events_unique_id(); + proto.mutable_host_execute_start_thunk()->clear_async_events_unique_id(); + + EXPECT_THAT(round_trip_proto, tsl::proto_testing::EqualsProto(proto)); +} + } // namespace } // namespace gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk.proto b/third_party/xla/xla/backends/gpu/runtime/thunk.proto index 0b100285aee..7543054d6c0 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk.proto +++ b/third_party/xla/xla/backends/gpu/runtime/thunk.proto @@ -19,6 +19,7 @@ package xla.gpu; import "xla/backends/gpu/runtime/convolution_filter_thunk.proto"; import "xla/backends/gpu/runtime/dynamic_slice_thunk.proto"; +import "xla/core/host_offloading/host_offloading_executable.proto"; import "xla/service/buffer_assignment.proto"; import "xla/service/gpu/gpu_conv_runner.proto"; import "xla/service/gpu/gpu_norm_runner.proto"; @@ -141,8 +142,21 @@ message CudnnThunkProto { optional int64 sdpa_dropout_seed = 3; } -message HostExecuteStartThunkProto {} -message HostExecuteDoneThunkProto {} +message ShapedSliceProto { + xla.buffer_assignment.BufferAllocationSliceProto slice = 1; + xla.ShapeProto shape = 2; +} + +message HostExecuteStartThunkProto { + HostOffloadingExecutableProto executable_proto = 1; + repeated ShapedSliceProto args = 2; + repeated ShapedSliceProto results = 3; + uint64 async_events_unique_id = 4; +} + +message HostExecuteDoneThunkProto { + uint64 async_events_unique_id = 1; +} message DynamicSliceThunkProto { SequentialThunkProto embedded_thunk = 1; @@ -164,11 +178,6 @@ message Memset32BitValueThunkProto { uint32 value = 2; } -message ShapedSliceProto { - xla.buffer_assignment.BufferAllocationSliceProto slice = 1; - xla.ShapeProto shape = 2; -} - message InfeedThunkProto { repeated ShapedSliceProto dest_slices = 1; } diff --git a/third_party/xla/xla/core/host_offloading/BUILD b/third_party/xla/xla/core/host_offloading/BUILD index fcd71ac95e3..a4d5c911348 100644 --- a/third_party/xla/xla/core/host_offloading/BUILD +++ b/third_party/xla/xla/core/host_offloading/BUILD @@ -118,7 +118,9 @@ strict_cc_test( tf_proto_library( name = "host_offloading_executable_proto", srcs = ["host_offloading_executable.proto"], - compatible_with = get_compatible_with_libtpu_portable(), + compatible_with = get_compatible_with_libtpu_portable() + [ + # copybara:uncomment "//buildenv/target:non_prod", + ], deps = [ "//xla/service:hlo_proto", "//xla/service/cpu:executable_proto",