From e32f20dd91c21bd40851f99e308c785cb227a4d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eusebio=20Dur=C3=A1n=20Monta=C3=B1a?= Date: Fri, 31 Oct 2025 03:17:23 -0700 Subject: [PATCH] Use factory function to create `CubSortThunk` The `CubSortThunk` constructor was calling a function that returns a `absl::StatusOr`, and ignoring non-ok statuses and just accessing the value. Presumably in prod the status is always ok, but making this failure case explicit. PiperOrigin-RevId: 826410861 --- .../xla/xla/backends/gpu/runtime/BUILD | 1 + .../backends/gpu/runtime/cub_sort_thunk.cc | 22 +++++++++++++---- .../xla/backends/gpu/runtime/cub_sort_thunk.h | 19 ++++++++++----- .../xla/service/gpu/ir_emitter_unnested.cc | 24 ++++++++++--------- 4 files changed, 45 insertions(+), 21 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 00503e46024..023bf01cad4 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -684,6 +684,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", diff --git a/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc index 1e3e58df979..915fc76b963 100644 --- a/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -283,16 +284,29 @@ CubSortRunnerInterface::Create(PrimitiveType type, : CreateCubSortRunner(type, platform_name); } -CubSortThunk::CubSortThunk( +absl::StatusOr> CubSortThunk::Create( ThunkInfo thunk_info, PrimitiveType type, std::optional value_type, absl::InlinedVector operands, absl::InlinedVector results, BufferAllocation::Slice scratch, bool descending, int64_t batch_size, - absl::string_view platform_name) + absl::string_view platform_name) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr runner, + CubSortRunnerInterface::Create(type, value_type, platform_name)); + + return absl::WrapUnique( + new CubSortThunk(thunk_info, std::move(runner), std::move(operands), + std::move(results), scratch, descending, batch_size)); +} + +CubSortThunk::CubSortThunk( + ThunkInfo thunk_info, std::unique_ptr runner, + absl::InlinedVector operands, + absl::InlinedVector results, + BufferAllocation::Slice scratch, bool descending, int64_t batch_size) : Thunk(Thunk::kCubSort, thunk_info), - runner_(CubSortRunnerInterface::Create(type, value_type, platform_name) - .value()), + runner_(std::move(runner)), operands_(std::move(operands)), results_(std::move(results)), scratch_(scratch), diff --git a/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.h b/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.h index 9165c2ad0ab..ef483b495d5 100644 --- a/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/cub_sort_thunk.h @@ -54,12 +54,13 @@ class CubSortRunnerInterface { class CubSortThunk : public Thunk { public: - CubSortThunk(ThunkInfo thunk_info, PrimitiveType type, - std::optional value_type, - absl::InlinedVector operands, - absl::InlinedVector results, - BufferAllocation::Slice scratch, bool descending, - int64_t batch_size, absl::string_view platform_name); + static absl::StatusOr> Create( + ThunkInfo thunk_info, PrimitiveType type, + std::optional value_type, + absl::InlinedVector operands, + absl::InlinedVector results, + BufferAllocation::Slice scratch, bool descending, int64_t batch_size, + absl::string_view platform_name); absl::Status ExecuteOnStream(const ExecuteParams& params) override { return runner_->Run(params, this); @@ -72,6 +73,12 @@ class CubSortThunk : public Thunk { int64_t batch_size() const { return batch_size_; } private: + CubSortThunk(ThunkInfo thunk_info, + std::unique_ptr runner, + absl::InlinedVector operands, + absl::InlinedVector results, + BufferAllocation::Slice scratch, bool descending, + int64_t batch_size); std::unique_ptr runner_; absl::InlinedVector operands_; absl::InlinedVector results_; diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 2b90b1b2bd3..f65f6bd6c2a 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -1072,17 +1072,19 @@ absl::Status IrEmitterUnnested::EmitCubDeviceRadixSort( TF_ASSIGN_OR_RETURN(xla::SortOptions options, instr->backend_config()); const Shape& operand_shape = instr->operand(0)->shape(); - auto thunk = std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation( - instr, ir_emitter_context_->GetNextThunkId()), - operand_shape.element_type(), - instr->operand_count() == 2 - ? std::optional(instr->operand(1)->shape().element_type()) - : std::nullopt, - operands, results, scratch, options.descending(), - Product(operand_shape.dimensions()) / - operand_shape.dimensions(operand_shape.dimensions().size() - 1), - ir_emitter_context_->platform_name()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr thunk, + CubSortThunk::Create( + Thunk::ThunkInfo::WithProfileAnnotation( + instr, ir_emitter_context_->GetNextThunkId()), + operand_shape.element_type(), + instr->operand_count() == 2 + ? std::optional(instr->operand(1)->shape().element_type()) + : std::nullopt, + operands, results, scratch, options.descending(), + Product(operand_shape.dimensions()) / + operand_shape.dimensions(operand_shape.dimensions().size() - 1), + ir_emitter_context_->platform_name())); AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); }