Use factory function to create CubSortThunk

The `CubSortThunk` constructor was calling a function that returns a `absl::StatusOr`, and ignoring non-ok statuses and just accessing the value.

Presumably in prod the status is always ok, but making this failure case explicit.

PiperOrigin-RevId: 826410861
This commit is contained in:
Eusebio Durán Montaña 2025-10-31 03:17:23 -07:00 committed by TensorFlower Gardener
parent adfd891fde
commit e32f20dd91
4 changed files with 45 additions and 21 deletions

View File

@ -684,6 +684,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
@ -283,16 +284,29 @@ CubSortRunnerInterface::Create(PrimitiveType type,
: CreateCubSortRunner(type, platform_name);
}
CubSortThunk::CubSortThunk(
absl::StatusOr<std::unique_ptr<CubSortThunk>> CubSortThunk::Create(
ThunkInfo thunk_info, PrimitiveType type,
std::optional<PrimitiveType> value_type,
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
absl::InlinedVector<BufferAllocation::Slice, 2> results,
BufferAllocation::Slice scratch, bool descending, int64_t batch_size,
absl::string_view platform_name)
absl::string_view platform_name) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<CubSortRunnerInterface> runner,
CubSortRunnerInterface::Create(type, value_type, platform_name));
return absl::WrapUnique<CubSortThunk>(
new CubSortThunk(thunk_info, std::move(runner), std::move(operands),
std::move(results), scratch, descending, batch_size));
}
CubSortThunk::CubSortThunk(
ThunkInfo thunk_info, std::unique_ptr<CubSortRunnerInterface> runner,
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
absl::InlinedVector<BufferAllocation::Slice, 2> results,
BufferAllocation::Slice scratch, bool descending, int64_t batch_size)
: Thunk(Thunk::kCubSort, thunk_info),
runner_(CubSortRunnerInterface::Create(type, value_type, platform_name)
.value()),
runner_(std::move(runner)),
operands_(std::move(operands)),
results_(std::move(results)),
scratch_(scratch),

View File

@ -54,12 +54,13 @@ class CubSortRunnerInterface {
class CubSortThunk : public Thunk {
public:
CubSortThunk(ThunkInfo thunk_info, PrimitiveType type,
std::optional<PrimitiveType> value_type,
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
absl::InlinedVector<BufferAllocation::Slice, 2> results,
BufferAllocation::Slice scratch, bool descending,
int64_t batch_size, absl::string_view platform_name);
static absl::StatusOr<std::unique_ptr<CubSortThunk>> Create(
ThunkInfo thunk_info, PrimitiveType type,
std::optional<PrimitiveType> value_type,
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
absl::InlinedVector<BufferAllocation::Slice, 2> results,
BufferAllocation::Slice scratch, bool descending, int64_t batch_size,
absl::string_view platform_name);
absl::Status ExecuteOnStream(const ExecuteParams& params) override {
return runner_->Run(params, this);
@ -72,6 +73,12 @@ class CubSortThunk : public Thunk {
int64_t batch_size() const { return batch_size_; }
private:
CubSortThunk(ThunkInfo thunk_info,
std::unique_ptr<CubSortRunnerInterface> runner,
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
absl::InlinedVector<BufferAllocation::Slice, 2> results,
BufferAllocation::Slice scratch, bool descending,
int64_t batch_size);
std::unique_ptr<CubSortRunnerInterface> runner_;
absl::InlinedVector<BufferAllocation::Slice, 2> operands_;
absl::InlinedVector<BufferAllocation::Slice, 2> results_;

View File

@ -1072,17 +1072,19 @@ absl::Status IrEmitterUnnested::EmitCubDeviceRadixSort(
TF_ASSIGN_OR_RETURN(xla::SortOptions options,
instr->backend_config<xla::SortOptions>());
const Shape& operand_shape = instr->operand(0)->shape();
auto thunk = std::make_unique<CubSortThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(
instr, ir_emitter_context_->GetNextThunkId()),
operand_shape.element_type(),
instr->operand_count() == 2
? std::optional(instr->operand(1)->shape().element_type())
: std::nullopt,
operands, results, scratch, options.descending(),
Product(operand_shape.dimensions()) /
operand_shape.dimensions(operand_shape.dimensions().size() - 1),
ir_emitter_context_->platform_name());
TF_ASSIGN_OR_RETURN(
std::unique_ptr<CubSortThunk> thunk,
CubSortThunk::Create(
Thunk::ThunkInfo::WithProfileAnnotation(
instr, ir_emitter_context_->GetNextThunkId()),
operand_shape.element_type(),
instr->operand_count() == 2
? std::optional(instr->operand(1)->shape().element_type())
: std::nullopt,
operands, results, scratch, options.descending(),
Product(operand_shape.dimensions()) /
operand_shape.dimensions(operand_shape.dimensions().size() - 1),
ir_emitter_context_->platform_name()));
AddThunkToThunkSequence(std::move(thunk));
return absl::OkStatus();
}