mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Use factory function to create CubSortThunk
The `CubSortThunk` constructor was calling a function that returns a `absl::StatusOr`, and ignoring non-ok statuses and just accessing the value. Presumably in prod the status is always ok, but making this failure case explicit. PiperOrigin-RevId: 826410861
This commit is contained in:
parent
adfd891fde
commit
e32f20dd91
|
|
@ -684,6 +684,7 @@ cc_library(
|
|||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
|
|
@ -283,16 +284,29 @@ CubSortRunnerInterface::Create(PrimitiveType type,
|
|||
: CreateCubSortRunner(type, platform_name);
|
||||
}
|
||||
|
||||
CubSortThunk::CubSortThunk(
|
||||
absl::StatusOr<std::unique_ptr<CubSortThunk>> CubSortThunk::Create(
|
||||
ThunkInfo thunk_info, PrimitiveType type,
|
||||
std::optional<PrimitiveType> value_type,
|
||||
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
|
||||
absl::InlinedVector<BufferAllocation::Slice, 2> results,
|
||||
BufferAllocation::Slice scratch, bool descending, int64_t batch_size,
|
||||
absl::string_view platform_name)
|
||||
absl::string_view platform_name) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<CubSortRunnerInterface> runner,
|
||||
CubSortRunnerInterface::Create(type, value_type, platform_name));
|
||||
|
||||
return absl::WrapUnique<CubSortThunk>(
|
||||
new CubSortThunk(thunk_info, std::move(runner), std::move(operands),
|
||||
std::move(results), scratch, descending, batch_size));
|
||||
}
|
||||
|
||||
CubSortThunk::CubSortThunk(
|
||||
ThunkInfo thunk_info, std::unique_ptr<CubSortRunnerInterface> runner,
|
||||
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
|
||||
absl::InlinedVector<BufferAllocation::Slice, 2> results,
|
||||
BufferAllocation::Slice scratch, bool descending, int64_t batch_size)
|
||||
: Thunk(Thunk::kCubSort, thunk_info),
|
||||
runner_(CubSortRunnerInterface::Create(type, value_type, platform_name)
|
||||
.value()),
|
||||
runner_(std::move(runner)),
|
||||
operands_(std::move(operands)),
|
||||
results_(std::move(results)),
|
||||
scratch_(scratch),
|
||||
|
|
|
|||
|
|
@ -54,12 +54,13 @@ class CubSortRunnerInterface {
|
|||
|
||||
class CubSortThunk : public Thunk {
|
||||
public:
|
||||
CubSortThunk(ThunkInfo thunk_info, PrimitiveType type,
|
||||
static absl::StatusOr<std::unique_ptr<CubSortThunk>> Create(
|
||||
ThunkInfo thunk_info, PrimitiveType type,
|
||||
std::optional<PrimitiveType> value_type,
|
||||
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
|
||||
absl::InlinedVector<BufferAllocation::Slice, 2> results,
|
||||
BufferAllocation::Slice scratch, bool descending,
|
||||
int64_t batch_size, absl::string_view platform_name);
|
||||
BufferAllocation::Slice scratch, bool descending, int64_t batch_size,
|
||||
absl::string_view platform_name);
|
||||
|
||||
absl::Status ExecuteOnStream(const ExecuteParams& params) override {
|
||||
return runner_->Run(params, this);
|
||||
|
|
@ -72,6 +73,12 @@ class CubSortThunk : public Thunk {
|
|||
int64_t batch_size() const { return batch_size_; }
|
||||
|
||||
private:
|
||||
CubSortThunk(ThunkInfo thunk_info,
|
||||
std::unique_ptr<CubSortRunnerInterface> runner,
|
||||
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
|
||||
absl::InlinedVector<BufferAllocation::Slice, 2> results,
|
||||
BufferAllocation::Slice scratch, bool descending,
|
||||
int64_t batch_size);
|
||||
std::unique_ptr<CubSortRunnerInterface> runner_;
|
||||
absl::InlinedVector<BufferAllocation::Slice, 2> operands_;
|
||||
absl::InlinedVector<BufferAllocation::Slice, 2> results_;
|
||||
|
|
|
|||
|
|
@ -1072,7 +1072,9 @@ absl::Status IrEmitterUnnested::EmitCubDeviceRadixSort(
|
|||
TF_ASSIGN_OR_RETURN(xla::SortOptions options,
|
||||
instr->backend_config<xla::SortOptions>());
|
||||
const Shape& operand_shape = instr->operand(0)->shape();
|
||||
auto thunk = std::make_unique<CubSortThunk>(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<CubSortThunk> thunk,
|
||||
CubSortThunk::Create(
|
||||
Thunk::ThunkInfo::WithProfileAnnotation(
|
||||
instr, ir_emitter_context_->GetNextThunkId()),
|
||||
operand_shape.element_type(),
|
||||
|
|
@ -1082,7 +1084,7 @@ absl::Status IrEmitterUnnested::EmitCubDeviceRadixSort(
|
|||
operands, results, scratch, options.descending(),
|
||||
Product(operand_shape.dimensions()) /
|
||||
operand_shape.dimensions(operand_shape.dimensions().size() - 1),
|
||||
ir_emitter_context_->platform_name());
|
||||
ir_emitter_context_->platform_name()));
|
||||
AddThunkToThunkSequence(std::move(thunk));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user