Use factory function to create CubSortThunk

The `CubSortThunk` constructor was calling a function that returns a `absl::StatusOr`, and ignoring non-ok statuses and just accessing the value.

Presumably in prod the status is always ok, but making this failure case explicit.

PiperOrigin-RevId: 826410861
This commit is contained in:
Eusebio Durán Montaña 2025-10-31 03:17:23 -07:00 committed by TensorFlower Gardener
parent adfd891fde
commit e32f20dd91
4 changed files with 45 additions and 21 deletions

View File

@ -684,6 +684,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log", "@com_google_absl//absl/log",
"@com_google_absl//absl/log:check", "@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h" #include "absl/container/inlined_vector.h"
#include "absl/log/check.h" #include "absl/log/check.h"
#include "absl/log/log.h" #include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
@ -283,16 +284,29 @@ CubSortRunnerInterface::Create(PrimitiveType type,
: CreateCubSortRunner(type, platform_name); : CreateCubSortRunner(type, platform_name);
} }
CubSortThunk::CubSortThunk( absl::StatusOr<std::unique_ptr<CubSortThunk>> CubSortThunk::Create(
ThunkInfo thunk_info, PrimitiveType type, ThunkInfo thunk_info, PrimitiveType type,
std::optional<PrimitiveType> value_type, std::optional<PrimitiveType> value_type,
absl::InlinedVector<BufferAllocation::Slice, 2> operands, absl::InlinedVector<BufferAllocation::Slice, 2> operands,
absl::InlinedVector<BufferAllocation::Slice, 2> results, absl::InlinedVector<BufferAllocation::Slice, 2> results,
BufferAllocation::Slice scratch, bool descending, int64_t batch_size, BufferAllocation::Slice scratch, bool descending, int64_t batch_size,
absl::string_view platform_name) absl::string_view platform_name) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<CubSortRunnerInterface> runner,
CubSortRunnerInterface::Create(type, value_type, platform_name));
return absl::WrapUnique<CubSortThunk>(
new CubSortThunk(thunk_info, std::move(runner), std::move(operands),
std::move(results), scratch, descending, batch_size));
}
CubSortThunk::CubSortThunk(
ThunkInfo thunk_info, std::unique_ptr<CubSortRunnerInterface> runner,
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
absl::InlinedVector<BufferAllocation::Slice, 2> results,
BufferAllocation::Slice scratch, bool descending, int64_t batch_size)
: Thunk(Thunk::kCubSort, thunk_info), : Thunk(Thunk::kCubSort, thunk_info),
runner_(CubSortRunnerInterface::Create(type, value_type, platform_name) runner_(std::move(runner)),
.value()),
operands_(std::move(operands)), operands_(std::move(operands)),
results_(std::move(results)), results_(std::move(results)),
scratch_(scratch), scratch_(scratch),

View File

@ -54,12 +54,13 @@ class CubSortRunnerInterface {
class CubSortThunk : public Thunk { class CubSortThunk : public Thunk {
public: public:
CubSortThunk(ThunkInfo thunk_info, PrimitiveType type, static absl::StatusOr<std::unique_ptr<CubSortThunk>> Create(
std::optional<PrimitiveType> value_type, ThunkInfo thunk_info, PrimitiveType type,
absl::InlinedVector<BufferAllocation::Slice, 2> operands, std::optional<PrimitiveType> value_type,
absl::InlinedVector<BufferAllocation::Slice, 2> results, absl::InlinedVector<BufferAllocation::Slice, 2> operands,
BufferAllocation::Slice scratch, bool descending, absl::InlinedVector<BufferAllocation::Slice, 2> results,
int64_t batch_size, absl::string_view platform_name); BufferAllocation::Slice scratch, bool descending, int64_t batch_size,
absl::string_view platform_name);
absl::Status ExecuteOnStream(const ExecuteParams& params) override { absl::Status ExecuteOnStream(const ExecuteParams& params) override {
return runner_->Run(params, this); return runner_->Run(params, this);
@ -72,6 +73,12 @@ class CubSortThunk : public Thunk {
int64_t batch_size() const { return batch_size_; } int64_t batch_size() const { return batch_size_; }
private: private:
CubSortThunk(ThunkInfo thunk_info,
std::unique_ptr<CubSortRunnerInterface> runner,
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
absl::InlinedVector<BufferAllocation::Slice, 2> results,
BufferAllocation::Slice scratch, bool descending,
int64_t batch_size);
std::unique_ptr<CubSortRunnerInterface> runner_; std::unique_ptr<CubSortRunnerInterface> runner_;
absl::InlinedVector<BufferAllocation::Slice, 2> operands_; absl::InlinedVector<BufferAllocation::Slice, 2> operands_;
absl::InlinedVector<BufferAllocation::Slice, 2> results_; absl::InlinedVector<BufferAllocation::Slice, 2> results_;

View File

@ -1072,17 +1072,19 @@ absl::Status IrEmitterUnnested::EmitCubDeviceRadixSort(
TF_ASSIGN_OR_RETURN(xla::SortOptions options, TF_ASSIGN_OR_RETURN(xla::SortOptions options,
instr->backend_config<xla::SortOptions>()); instr->backend_config<xla::SortOptions>());
const Shape& operand_shape = instr->operand(0)->shape(); const Shape& operand_shape = instr->operand(0)->shape();
auto thunk = std::make_unique<CubSortThunk>( TF_ASSIGN_OR_RETURN(
Thunk::ThunkInfo::WithProfileAnnotation( std::unique_ptr<CubSortThunk> thunk,
instr, ir_emitter_context_->GetNextThunkId()), CubSortThunk::Create(
operand_shape.element_type(), Thunk::ThunkInfo::WithProfileAnnotation(
instr->operand_count() == 2 instr, ir_emitter_context_->GetNextThunkId()),
? std::optional(instr->operand(1)->shape().element_type()) operand_shape.element_type(),
: std::nullopt, instr->operand_count() == 2
operands, results, scratch, options.descending(), ? std::optional(instr->operand(1)->shape().element_type())
Product(operand_shape.dimensions()) / : std::nullopt,
operand_shape.dimensions(operand_shape.dimensions().size() - 1), operands, results, scratch, options.descending(),
ir_emitter_context_->platform_name()); Product(operand_shape.dimensions()) /
operand_shape.dimensions(operand_shape.dimensions().size() - 1),
ir_emitter_context_->platform_name()));
AddThunkToThunkSequence(std::move(thunk)); AddThunkToThunkSequence(std::move(thunk));
return absl::OkStatus(); return absl::OkStatus();
} }