mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Use factory function to create CubSortThunk
The `CubSortThunk` constructor was calling a function that returns a `absl::StatusOr`, and ignoring non-ok statuses and just accessing the value. Presumably in prod the status is always ok, but making this failure case explicit. PiperOrigin-RevId: 826410861
This commit is contained in:
parent
adfd891fde
commit
e32f20dd91
|
|
@ -684,6 +684,7 @@ cc_library(
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
"@com_google_absl//absl/log",
|
"@com_google_absl//absl/log",
|
||||||
"@com_google_absl//absl/log:check",
|
"@com_google_absl//absl/log:check",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "absl/log/check.h"
|
#include "absl/log/check.h"
|
||||||
#include "absl/log/log.h"
|
#include "absl/log/log.h"
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
|
@ -283,16 +284,29 @@ CubSortRunnerInterface::Create(PrimitiveType type,
|
||||||
: CreateCubSortRunner(type, platform_name);
|
: CreateCubSortRunner(type, platform_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
CubSortThunk::CubSortThunk(
|
absl::StatusOr<std::unique_ptr<CubSortThunk>> CubSortThunk::Create(
|
||||||
ThunkInfo thunk_info, PrimitiveType type,
|
ThunkInfo thunk_info, PrimitiveType type,
|
||||||
std::optional<PrimitiveType> value_type,
|
std::optional<PrimitiveType> value_type,
|
||||||
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
|
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
|
||||||
absl::InlinedVector<BufferAllocation::Slice, 2> results,
|
absl::InlinedVector<BufferAllocation::Slice, 2> results,
|
||||||
BufferAllocation::Slice scratch, bool descending, int64_t batch_size,
|
BufferAllocation::Slice scratch, bool descending, int64_t batch_size,
|
||||||
absl::string_view platform_name)
|
absl::string_view platform_name) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
std::unique_ptr<CubSortRunnerInterface> runner,
|
||||||
|
CubSortRunnerInterface::Create(type, value_type, platform_name));
|
||||||
|
|
||||||
|
return absl::WrapUnique<CubSortThunk>(
|
||||||
|
new CubSortThunk(thunk_info, std::move(runner), std::move(operands),
|
||||||
|
std::move(results), scratch, descending, batch_size));
|
||||||
|
}
|
||||||
|
|
||||||
|
CubSortThunk::CubSortThunk(
|
||||||
|
ThunkInfo thunk_info, std::unique_ptr<CubSortRunnerInterface> runner,
|
||||||
|
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
|
||||||
|
absl::InlinedVector<BufferAllocation::Slice, 2> results,
|
||||||
|
BufferAllocation::Slice scratch, bool descending, int64_t batch_size)
|
||||||
: Thunk(Thunk::kCubSort, thunk_info),
|
: Thunk(Thunk::kCubSort, thunk_info),
|
||||||
runner_(CubSortRunnerInterface::Create(type, value_type, platform_name)
|
runner_(std::move(runner)),
|
||||||
.value()),
|
|
||||||
operands_(std::move(operands)),
|
operands_(std::move(operands)),
|
||||||
results_(std::move(results)),
|
results_(std::move(results)),
|
||||||
scratch_(scratch),
|
scratch_(scratch),
|
||||||
|
|
|
||||||
|
|
@ -54,12 +54,13 @@ class CubSortRunnerInterface {
|
||||||
|
|
||||||
class CubSortThunk : public Thunk {
|
class CubSortThunk : public Thunk {
|
||||||
public:
|
public:
|
||||||
CubSortThunk(ThunkInfo thunk_info, PrimitiveType type,
|
static absl::StatusOr<std::unique_ptr<CubSortThunk>> Create(
|
||||||
|
ThunkInfo thunk_info, PrimitiveType type,
|
||||||
std::optional<PrimitiveType> value_type,
|
std::optional<PrimitiveType> value_type,
|
||||||
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
|
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
|
||||||
absl::InlinedVector<BufferAllocation::Slice, 2> results,
|
absl::InlinedVector<BufferAllocation::Slice, 2> results,
|
||||||
BufferAllocation::Slice scratch, bool descending,
|
BufferAllocation::Slice scratch, bool descending, int64_t batch_size,
|
||||||
int64_t batch_size, absl::string_view platform_name);
|
absl::string_view platform_name);
|
||||||
|
|
||||||
absl::Status ExecuteOnStream(const ExecuteParams& params) override {
|
absl::Status ExecuteOnStream(const ExecuteParams& params) override {
|
||||||
return runner_->Run(params, this);
|
return runner_->Run(params, this);
|
||||||
|
|
@ -72,6 +73,12 @@ class CubSortThunk : public Thunk {
|
||||||
int64_t batch_size() const { return batch_size_; }
|
int64_t batch_size() const { return batch_size_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
CubSortThunk(ThunkInfo thunk_info,
|
||||||
|
std::unique_ptr<CubSortRunnerInterface> runner,
|
||||||
|
absl::InlinedVector<BufferAllocation::Slice, 2> operands,
|
||||||
|
absl::InlinedVector<BufferAllocation::Slice, 2> results,
|
||||||
|
BufferAllocation::Slice scratch, bool descending,
|
||||||
|
int64_t batch_size);
|
||||||
std::unique_ptr<CubSortRunnerInterface> runner_;
|
std::unique_ptr<CubSortRunnerInterface> runner_;
|
||||||
absl::InlinedVector<BufferAllocation::Slice, 2> operands_;
|
absl::InlinedVector<BufferAllocation::Slice, 2> operands_;
|
||||||
absl::InlinedVector<BufferAllocation::Slice, 2> results_;
|
absl::InlinedVector<BufferAllocation::Slice, 2> results_;
|
||||||
|
|
|
||||||
|
|
@ -1072,7 +1072,9 @@ absl::Status IrEmitterUnnested::EmitCubDeviceRadixSort(
|
||||||
TF_ASSIGN_OR_RETURN(xla::SortOptions options,
|
TF_ASSIGN_OR_RETURN(xla::SortOptions options,
|
||||||
instr->backend_config<xla::SortOptions>());
|
instr->backend_config<xla::SortOptions>());
|
||||||
const Shape& operand_shape = instr->operand(0)->shape();
|
const Shape& operand_shape = instr->operand(0)->shape();
|
||||||
auto thunk = std::make_unique<CubSortThunk>(
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
std::unique_ptr<CubSortThunk> thunk,
|
||||||
|
CubSortThunk::Create(
|
||||||
Thunk::ThunkInfo::WithProfileAnnotation(
|
Thunk::ThunkInfo::WithProfileAnnotation(
|
||||||
instr, ir_emitter_context_->GetNextThunkId()),
|
instr, ir_emitter_context_->GetNextThunkId()),
|
||||||
operand_shape.element_type(),
|
operand_shape.element_type(),
|
||||||
|
|
@ -1082,7 +1084,7 @@ absl::Status IrEmitterUnnested::EmitCubDeviceRadixSort(
|
||||||
operands, results, scratch, options.descending(),
|
operands, results, scratch, options.descending(),
|
||||||
Product(operand_shape.dimensions()) /
|
Product(operand_shape.dimensions()) /
|
||||||
operand_shape.dimensions(operand_shape.dimensions().size() - 1),
|
operand_shape.dimensions(operand_shape.dimensions().size() - 1),
|
||||||
ir_emitter_context_->platform_name());
|
ir_emitter_context_->platform_name()));
|
||||||
AddThunkToThunkSequence(std::move(thunk));
|
AddThunkToThunkSequence(std::move(thunk));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user