mirror of https://github.com/zebrajr/tensorflow.git
[XLA] Add Reduce, DynamicSlice and DynamicSliceUpdate to HloEvaluator.
- Reduce is disabled explicitly for constant folding, as not all types of
  embedded computation can currently be supported by the evaluator.
- Added support for evaluating a full HloModule to HloEvaluator.
- Minor signature change to Evaluate().

PiperOrigin-RevId: 163299238
This commit is contained in:
parent a524701723
commit 631a364cd1
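For orientation, a minimal sketch of how the reworked API would be called. This snippet is not part of the commit: the computation (a single negate on a hypothetical f32[2,2] parameter) is made up, but the helpers it uses (HloComputation::Builder, Literal::CreateR2, and the new pass-by-const-reference HloEvaluator::Evaluate) all appear in the diff below:

// Hedged sketch, not from this commit: build a tiny computation and run it
// through the evaluator using the new const-reference overload.
HloComputation::Builder b("example");
auto param = b.AddInstruction(HloInstruction::CreateParameter(
    0, ShapeUtil::MakeShape(F32, {2, 2}), "x"));
b.AddInstruction(
    HloInstruction::CreateUnary(param->shape(), HloOpcode::kNegate, param));
auto computation = b.Build();

auto arg = Literal::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
std::vector<const Literal*> args = {arg.get()};

HloEvaluator evaluator;
// An HloModule& overload with the same argument convention is added as well;
// it simply evaluates the module's entry computation.
std::unique_ptr<Literal> result =
    evaluator.Evaluate(*computation, args).ConsumeValueOrDie();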
@@ -30,8 +30,8 @@ ExecutorExecutable::ExecutorExecutable(std::unique_ptr<HloModule> hlo_module)
 ExecutorExecutable::~ExecutorExecutable() {}
 
-static se::DeviceMemoryBase AllocateSingleOutput(sep::ExecutorExecutor* executor,
-                                                 const Literal& literal) {
+static se::DeviceMemoryBase AllocateSingleOutput(
+    sep::ExecutorExecutor* executor, const Literal& literal) {
   int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape()));
   void* buf = executor->Allocate(size);
   const void* src = literal.InternalData();
@@ -39,8 +39,8 @@ static se::DeviceMemoryBase AllocateSingleOutput(sep::ExecutorExecutor* executor
   return se::DeviceMemoryBase(buf, size);
 }
 
-static se::DeviceMemoryBase AllocateOutputBuffer(sep::ExecutorExecutor* executor,
-                                                 const Literal& literal) {
+static se::DeviceMemoryBase AllocateOutputBuffer(
+    sep::ExecutorExecutor* executor, const Literal& literal) {
   const Shape& shape = literal.shape();
   if (shape.element_type() != xla::TUPLE) {
     return AllocateSingleOutput(executor, literal);
@@ -96,7 +96,7 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream(
   // Execute the graph using the evaluator
   HloEvaluator evaluator;
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> output,
-                      evaluator.Evaluate(computation, arg_literals_ptrs));
+                      evaluator.Evaluate(*computation, arg_literals_ptrs));
 
   // Copy the result into the return buffer
   perftools::gputools::StreamExecutor* executor(stream->parent());
@@ -139,6 +139,5 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteAsyncOnStream(
   return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
 }
 
-
 }  // namespace executorplugin
 }  // namespace xla
@@ -51,9 +51,12 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
         computation->root_instruction() != instruction) {
       continue;
     }
-    // Skip Constant and Parameter operation.
+    // Skip Constant, Parameter, and Reduce operations.
+    // TODO(b/35975797): Enable Reduce operation once arbitrary computations
+    // are supported by the evaluator.
    if (instruction->opcode() == HloOpcode::kParameter ||
-        instruction->opcode() == HloOpcode::kConstant) {
+        instruction->opcode() == HloOpcode::kConstant ||
+        instruction->opcode() == HloOpcode::kReduce) {
      continue;
    }
    // Skip instructions with non-constant operands.
@@ -654,12 +654,262 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     return Status::OK();
   };
 
+  Status HandleDynamicSlice(HloInstruction* dynamic_slice,
+                            HloInstruction* operand,
+                            HloInstruction* start_indices) override {
+    auto result_shape = dynamic_slice->shape();
+    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+                        ShapeInference::InferDynamicSliceShape(
+                            operand->shape(), start_indices->shape(),
+                            dynamic_slice->dynamic_slice_sizes()));
+    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+        << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
+        << "but is inferred to be: "
+        << ShapeUtil::HumanString(inferred_return_shape);
+    TF_RET_CHECK(
+        primitive_util::IsIntegralType(start_indices->shape().element_type()));
+
+    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+    const Literal& start_indices_literal =
+        parent_->GetEvaluatedLiteralFor(start_indices);
+
+    switch (start_indices->shape().element_type()) {
+      case S32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_slice],
+            DynamicSlice<int32>(operand_literal, start_indices_literal,
+                                result_shape));
+      } break;
+      case S64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_slice],
+            DynamicSlice<int64>(operand_literal, start_indices_literal,
+                                result_shape));
+      } break;
+      case U32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_slice],
+            DynamicSlice<uint32>(operand_literal, start_indices_literal,
+                                 result_shape));
+      } break;
+      case U64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_slice],
+            DynamicSlice<uint64>(operand_literal, start_indices_literal,
+                                 result_shape));
+      } break;
+      default:
+        LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for "
+                      "start_indices: "
+                   << PrimitiveType_Name(start_indices->shape().element_type());
+    }
+
+    return Status::OK();
+  };
+
+  Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
+                                  HloInstruction* operand,
+                                  HloInstruction* update,
+                                  HloInstruction* start_indices) override {
+    auto result_shape = dynamic_update_slice->shape();
+    TF_ASSIGN_OR_RETURN(
+        auto inferred_return_shape,
+        ShapeInference::InferDynamicUpdateSliceShape(
+            operand->shape(), update->shape(), start_indices->shape()));
+    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
+        << "return shape is set to: " << ShapeUtil::HumanString(result_shape)
+        << "but is inferred to be: "
+        << ShapeUtil::HumanString(inferred_return_shape);
+    TF_RET_CHECK(
+        primitive_util::IsIntegralType(start_indices->shape().element_type()));
+    TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape()));
+
+    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+    const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update);
+    const Literal& start_indices_literal =
+        parent_->GetEvaluatedLiteralFor(start_indices);
+
+    switch (start_indices->shape().element_type()) {
+      case S32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<int32>(operand_literal, update_literal,
+                                      start_indices_literal));
+      } break;
+      case S64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<int64>(operand_literal, update_literal,
+                                      start_indices_literal));
+      } break;
+      case U32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<uint32>(operand_literal, update_literal,
+                                       start_indices_literal));
+      } break;
+      case U64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<uint64>(operand_literal, update_literal,
+                                       start_indices_literal));
+      } break;
+      default:
+        LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
+                      "start_indices: "
+                   << PrimitiveType_Name(start_indices->shape().element_type());
+    }
+
+    return Status::OK();
+  };
+
+  Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+                      HloInstruction* init_value,
+                      tensorflow::gtl::ArraySlice<int64> dimensions,
+                      HloComputation* function) override {
+    TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) ==
+                 ShapeUtil::Rank(arg->shape()) - dimensions.size());
+    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
+                        ShapeInference::InferReduceShape(
+                            /*arg=*/arg->shape(),
+                            /*init_value=*/init_value->shape(),
+                            /*dimensions_to_reduce=*/dimensions,
+                            /*to_apply=*/function->ComputeProgramShape()));
+    TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
+        << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
+        << "but is inferred to be: "
+        << ShapeUtil::HumanString(inferred_return_shape);
+
+    const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
+    VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString();
+    const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value);
+    VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString();
+    TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
+    auto init_scalar = init_literal.Get<ReturnT>({});
+
+    auto result = Literal::CreateFromShape(reduce->shape());
+
+    const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions());
+    std::vector<int64> arg_dim_steps(arg_dimensions.size());
+    std::vector<int64> arg_dim_counts(arg_dimensions.size());
+    for (const int64 dim : dimensions) {
+      arg_dim_steps[dim] = 1;
+      arg_dim_counts[dim] = arg_dimensions[dim];
+    }
+
+    // Create mapping from result index to arg index.
+    const int64 result_rank = ShapeUtil::Rank(result->shape());
+    int64 result_dim = 0;
+    std::vector<int64> result_to_arg_index(result_rank);
+    for (int64 i = 0; i < arg_dimensions.size(); ++i) {
+      if (arg_dim_steps[i] == 0) {
+        result_to_arg_index[result_dim] = i;
+        ++result_dim;
+      }
+    }
+
+    // For each resulting dimension, calculate and assign computed value.
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+          ReturnT result_val = init_scalar;
+
+          std::vector<int64> base(arg_dimensions.size());
+          for (int64 i = 0; i < multi_index.size(); ++i) {
+            base[result_to_arg_index[i]] = multi_index[i];
+          }
+
+          auto func = [&](const std::vector<int64>& input_index) {
+            auto curr_val = arg_literal.Get<ReturnT>(input_index);
+
+            // Evaluate computation with specified literal operands.
+            auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
+            auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
+            std::vector<const Literal*> args = {curr_val_literal.get(),
+                                                result_val_literal.get()};
+
+            // We need a new visitor for each evaluation, so that the same
+            // computation can be visited more than once (with different
+            // inputs).
+            HloEvaluator embedded_evaluator;
+            std::unique_ptr<Literal> computed_result =
+                embedded_evaluator.Evaluate(*function, args)
+                    .ConsumeValueOrDie();
+
+            // Assign computed result to result_val.
+            result_val = computed_result->Get<ReturnT>({});
+
+            return true;
+          };
+
+          ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
+                                  arg_dim_steps, func);
+
+          return result_val;
+        }));
+
+    parent_->evaluated_[reduce] = std::move(result);
+    return Status::OK();
+  };
+
   Status Preprocess(HloInstruction* hlo) override {
     VLOG(2) << hlo->ToString();
     return Status::OK();
   };
 
  private:
+  template <typename IndexT>
+  StatusOr<std::unique_ptr<Literal>> DynamicSlice(
+      const Literal& operand_literal, const Literal& start_indices_literal,
+      const Shape& result_shape) {
+    const auto& start_indices_typed =
+        start_indices_literal.GetArraySlice<IndexT>();
+    std::vector<int64> start(start_indices_typed.begin(),
+                             start_indices_typed.end());
+
+    std::vector<int64> operand_indices(start.size(), 0);
+
+    auto result = Literal::CreateFromShape(result_shape);
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+          std::transform(multi_index.begin(), multi_index.end(), start.begin(),
+                         operand_indices.begin(), std::plus<int64>());
+
+          return operand_literal.Get<ReturnT>(operand_indices);
+        }));
+
+    return std::move(result);
+  }
+
+  template <typename IndexT>
+  StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
+      const Literal& operand_literal, const Literal& update_literal,
+      const Literal& start_indices_literal) {
+    const auto& start_indices_typed =
+        start_indices_literal.GetArraySlice<IndexT>();
+    const std::vector<int64> start(start_indices_typed.begin(),
+                                   start_indices_typed.end());
+
+    auto result = MakeUnique<Literal>(operand_literal);
+    std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
+
+    auto func = [&](const std::vector<int64>& update_index) {
+      std::transform(update_index.begin(), update_index.end(), start.begin(),
+                     result_index.begin(), std::plus<int64>());
+
+      result->Set<ReturnT>(result_index,
+                           update_literal.Get<ReturnT>(update_index));
+      return true;
+    };
+
+    std::vector<int64> base(update_literal.shape().dimensions_size(), 0);
+    std::vector<int64> step(update_literal.shape().dimensions_size(), 1);
+    ShapeUtil::ForEachIndex(update_literal.shape(), base,
+                            AsInt64Slice(update_literal.shape().dimensions()),
+                            step, func);
+
+    return std::move(result);
+  }
+
   StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
       HloInstruction* instruction,
       const std::function<ReturnT(ReturnT)>& unary_op) {
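The index bookkeeping in HandleReduce above is dense; the following standalone sketch (plain C++, no XLA dependencies, a made-up [2,3] array reduced over dimension 1) replays the same arg_dim_steps / arg_dim_counts / result_to_arg_index construction:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical arg shape [2,3], reducing dimension 1 -> result shape [2].
  const std::vector<int64_t> arg_dims = {2, 3};
  const std::vector<int64_t> reduce_dims = {1};

  // Mirrors arg_dim_steps/arg_dim_counts in HandleReduce: reduced dimensions
  // get step 1 and their full count; kept dimensions stay at 0.
  std::vector<int64_t> steps(arg_dims.size(), 0);
  std::vector<int64_t> counts(arg_dims.size(), 0);
  for (int64_t dim : reduce_dims) {
    steps[dim] = 1;
    counts[dim] = arg_dims[dim];
  }

  // Mirrors result_to_arg_index: the kept dimensions (step 0), in order.
  std::vector<int64_t> result_to_arg_index;
  for (size_t i = 0; i < arg_dims.size(); ++i) {
    if (steps[i] == 0) result_to_arg_index.push_back(i);
  }

  // For each result index, build the "base" arg index it expands from and
  // report which input elements fold into it.
  for (int64_t r = 0; r < arg_dims[result_to_arg_index[0]]; ++r) {
    std::vector<int64_t> base(arg_dims.size(), 0);
    base[result_to_arg_index[0]] = r;
    // The reduction walks base + {0..counts[1]-1} along dimension 1, i.e.
    // result[r] = f(init, arg[r][0], arg[r][1], arg[r][2]).
    std::cout << "result[" << r << "] reduces arg[" << r << "][0.."
              << counts[1] - 1 << "]\n";
  }
}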
@@ -771,14 +1021,28 @@ HloEvaluator::HloEvaluator() {
 }
 
 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
-    HloComputation* computation,
-    tensorflow::gtl::ArraySlice<const Literal*> args) {
-  arg_literals_ = args;
+    const HloModule& module,
+    tensorflow::gtl::ArraySlice<const Literal*> arg_literals) {
+  XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
+
+  arg_literals_ = arg_literals;
   evaluated_.clear();
 
-  TF_RETURN_IF_ERROR(computation->Accept(this));
+  TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
 
   return MakeUnique<Literal>(
-      GetEvaluatedLiteralFor(computation->root_instruction()));
+      GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()));
+}
+
+StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+    const HloComputation& computation,
+    tensorflow::gtl::ArraySlice<const Literal*> arg_literals) {
+  arg_literals_ = arg_literals;
+  evaluated_.clear();
+
+  TF_RETURN_IF_ERROR(computation.Accept(this));
+  return MakeUnique<Literal>(
+      GetEvaluatedLiteralFor(computation.root_instruction()));
 }
 
 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
@@ -930,7 +1194,8 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite,
       break;
     }
     default:
-      LOG(FATAL) << "unknown/unhandled primitive type.";
+      LOG(FATAL) << "HandleIsFinite: unknown/unhandled primitive type: "
+                 << PrimitiveType_Name(operand->shape().element_type());
   }
 
   return Status::OK();
@@ -1009,7 +1274,8 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode,
           Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
     } break;
     default:
-      LOG(FATAL) << "unknown primitive type.";
+      LOG(FATAL) << "HandleCompare: unknown primitive type: "
+                 << PrimitiveType_Name(lhs->shape().element_type());
   }
 
   return Status::OK();
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -36,9 +37,17 @@ namespace xla {
 class HloEvaluator : public DfsHloVisitorWithDefault {
  public:
   HloEvaluator();
-  // Evaluates a HLO computation and an array of pointers to literals.
-  // Return the evaluated result as literal if successful.
-  // Precondition: argument literals are corresponds to the input computation's
+  // Evaluates an HLO module and an array of pointers to literals.
+  // Returns the evaluated result as a literal if successful.
+  // Precondition: argument literals correspond to each input computation's
+  // parameters in their post-ordering. See comment below for example.
+  StatusOr<std::unique_ptr<Literal>> Evaluate(
+      const HloModule& module,
+      tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
+
+  // Evaluates an HLO computation and an array of pointers to literals.
+  // Returns the evaluated result as a literal if successful.
+  // Precondition: argument literals correspond to the input computation's
   // parameters in their post-ordering. For e.g., consider the following graph:
   //
   //                *
@@ -51,7 +60,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
   // The input literals array will have its first literal map to Parameter0 and
   // the second map to Parameter1.
   StatusOr<std::unique_ptr<Literal>> Evaluate(
-      HloComputation* computation,
+      const HloComputation& computation,
       tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
 
   // Evaluates a single HLO instruction and an array of pointers to literals.
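As a sketch of the post-ordering precondition spelled out in these comments (not from this commit; the scalar add computation mirrors the ReduceAdd test below): args[0] binds to parameter 0 and args[1] to parameter 1:

// Hedged sketch: two parameters, bound in post-order by Evaluate.
HloComputation::Builder b("add");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
auto lhs = b.AddInstruction(
    HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
auto rhs = b.AddInstruction(
    HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
b.AddInstruction(
    HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, lhs, rhs));

auto lhs_lit = Literal::CreateR0<float>(1.f);
auto rhs_lit = Literal::CreateR0<float>(2.f);
// args[0] -> Parameter 0 ("lhs"), args[1] -> Parameter 1 ("rhs").
std::vector<const Literal*> args = {lhs_lit.get(), rhs_lit.get()};

HloEvaluator evaluator;
auto result = evaluator.Evaluate(*b.Build(), args).ConsumeValueOrDie();
// result holds the scalar 3.0f.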
@@ -188,7 +188,7 @@ TEST_F(HloEvaluatorTest, DoesAbs) {
 // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor
 // constant operands.
 TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
-  HloComputation::Builder builder(TestName());
+  HloComputation::Builder b(TestName());
   auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
   auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
   auto rhs2 = Literal::CreateR2<int64>({{1, -20}, {-100, 4}});
@@ -205,9 +205,9 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
   auto root_instruction = HloInstruction::CreateBinary(
       shape, HloOpcode::kAdd, lhs_instruction.get(), param_rhs2.get());
 
-  builder.AddInstruction(std::move(root_instruction));
+  b.AddInstruction(std::move(root_instruction));
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(builder.Build().get(), args).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), args).ConsumeValueOrDie();
 
   auto expected = Literal::CreateR2<int64>({{4, -16}, {-196, 12}});
@@ -216,22 +216,22 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
 
 // Verifies Reshape operation is correctly evaluated.
 TEST_F(HloEvaluatorTest, DoesReshape) {
-  HloComputation::Builder builder(TestName());
+  HloComputation::Builder b(TestName());
   const int64 dimensions[] = {11, 8, 7, 5, 9};
   TF_ASSERT_OK_AND_ASSIGN(auto literal,
                           LiteralTestUtil::CreateRandomLiteral<F32>(
                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
   auto literal_clone = literal->CloneToUnique();
-  HloInstruction* literal_instruction = builder.AddInstruction(
-      HloInstruction::CreateConstant(std::move(literal)));
+  HloInstruction* literal_instruction =
+      b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
 
   Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
   const int64 permutation[] = {1, 2, 0, 4, 3};
-  builder.AddInstruction(
+  b.AddInstruction(
       HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
   result->EachCell<NativeT>(
@@ -243,24 +243,24 @@ TEST_F(HloEvaluatorTest, DoesReshape) {
 
 // Verifies Broadcast operation is correctly evaluated.
 TEST_F(HloEvaluatorTest, DoesBroadcast) {
-  HloComputation::Builder builder(TestName());
+  HloComputation::Builder b(TestName());
   auto input_literal = Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
   auto output_literal = Literal::CreateR3<int32>(
       {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}});
-  HloInstruction* literal_instruction = builder.AddInstruction(
+  HloInstruction* literal_instruction = b.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal)));
 
-  builder.AddInstruction(HloInstruction::CreateBroadcast(
+  b.AddInstruction(HloInstruction::CreateBroadcast(
       output_literal->shape(), literal_instruction, {1, 2}));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   LiteralTestUtil::ExpectEqual(*result, *output_literal);
 }
 
 TEST_F(HloEvaluatorTest, ConvertWithSameLayout) {
-  HloComputation::Builder builder(TestName());
+  HloComputation::Builder b(TestName());
 
   auto input_literal = Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
   auto expected =
@@ -268,19 +268,18 @@ TEST_F(HloEvaluatorTest, ConvertWithSameLayout) {
   ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
                                                expected->shape()));
 
-  HloInstruction* constant = builder.AddInstruction(
+  HloInstruction* constant = b.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal)));
-  builder.AddInstruction(
-      HloInstruction::CreateConvert(expected->shape(), constant));
+  b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   LiteralTestUtil::ExpectEqual(*result, *expected);
 }
 
 TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) {
-  HloComputation::Builder builder(TestName());
+  HloComputation::Builder b(TestName());
 
   auto input_literal = Literal::CreateR2WithLayout<int32>(
       {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
@@ -289,13 +288,12 @@ TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) {
   ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
                                                 expected->shape()));
 
-  HloInstruction* constant = builder.AddInstruction(
+  HloInstruction* constant = b.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal)));
-  builder.AddInstruction(
-      HloInstruction::CreateConvert(expected->shape(), constant));
+  b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   LiteralTestUtil::ExpectEqual(*result, *expected);
 }
@@ -355,7 +353,7 @@ TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
       shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   auto expected_array = MakeUnique<Array4D<float>>(8, 5, 1, 1);
   expected_array->Fill(kPadValue);
@@ -398,7 +396,7 @@ TEST_F(HloEvaluatorTest, NegativePadding2D) {
                                                  r2_padding_on_dim0_dim1));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
   auto expected_array = MakeUnique<Array2D<float>>(1, 5);
@@ -442,7 +440,7 @@ TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
                                                  r2_padding_on_dim0_dim1));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   auto expected_array = MakeUnique<Array2D<float>>(0, 9);
   auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
@@ -477,7 +475,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) {
       shape, HloOpcode::kDot, lhs_instruction, rhs_instruction));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   // clang-format off
   auto expected_array = Array2D<float>({
@@ -519,7 +517,7 @@ TEST_F(HloEvaluatorTest, DotRank1AndRank2) {
      shape, HloOpcode::kDot, lhs_instruction, rhs_instruction));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   auto expected = Literal::CreateR1<float>({22.f, 28.f});
 
@@ -559,10 +557,13 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) {
       shape, HloOpcode::kDot, lhs_instruction, rhs_instruction));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   auto expected_array = Array2D<float>({
-      {22.f, 28.f}, {58.f, 76.f}, {94.f, 124.f}, {130.f, 172.f},
+      {22.f, 28.f},
+      {58.f, 76.f},
+      {94.f, 124.f},
+      {130.f, 172.f},
   });
   auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
 
@@ -606,7 +607,7 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) {
       shape, lhs_instruction, rhs_instruction, window, dnums));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
   auto expected = Literal::CreateR3FromArray3D<float>(expected_array);
@@ -660,7 +661,7 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
       shape, lhs_instruction, rhs_instruction, window, dnums));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   Array4D<float> expected_array(1, 1, 4, 4);
   // clang-format off
@@ -736,7 +737,7 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) {
       shape, lhs_instruction, rhs_instruction, window, dnums));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   // clang-format off
   // Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -793,7 +794,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
       shape, lhs_instruction, rhs_instruction, window, dnums));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   Array4D<float> expected_array(1, 1, 7, 7);
   expected_array.FillWithYX(Array2D<float>({
@@ -856,7 +857,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
       shape, lhs_instruction, rhs_instruction, window, dnums));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   Array4D<float> expected_array(1, 1, 8, 8);
   expected_array.FillWithYX(Array2D<float>({
@@ -927,7 +928,7 @@ TEST_F(HloEvaluatorTest,
       shape, lhs_instruction, rhs_instruction, window, dnums));
 
   std::unique_ptr<Literal> result =
-      evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
 
   Array4D<float> expected_array(1, 1, 9, 3);
   expected_array.FillWithYX(Array2D<float>({
@@ -946,5 +947,115 @@ TEST_F(HloEvaluatorTest,
   LiteralTestUtil::ExpectEqual(*expected, *result);
 }
 
+TEST_F(HloEvaluatorTest, ReduceAdd) {
+  HloComputation::Builder b(TestName());
+
+  // arg:
+  // f32[2,3] {
+  //  { 1, 2, 3 },
+  //  { 5, 6, 7 },
+  // }
+  auto arg_array = MakeUnique<Array2D<float>>(2, 3);
+  arg_array->FillUnique(1.0f);
+  auto arg_literal = Literal::CreateR2FromArray2D<float>(*arg_array);
+
+  HloInstruction* arg_instruction =
+      b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
+
+  auto init_value = b.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+
+  HloComputation::Builder add_computation("add");
+  Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+  auto param_lhs = add_computation.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+  auto param_rhs = add_computation.AddInstruction(
+      HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+  add_computation.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
+  auto add_func = add_computation.Build();
+
+  Shape shape = ShapeUtil::MakeShape(F32, {2});
+  b.AddInstruction(HloInstruction::CreateReduce(
+      shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{1},
+      add_func.get()));
+
+  std::unique_ptr<Literal> result =
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
+
+  auto expected = Literal::CreateR1<float>({6, 18});
+
+  LiteralTestUtil::ExpectEqual(*expected, *result);
+}
+
+TEST_F(HloEvaluatorTest, DynamicSlice) {
+  HloComputation::Builder b(TestName());
+
+  // arg:
+  // f32[2,4] {
+  //  { 1, 2, 3, 4 },
+  //  { 5, 6, 7, 8 },
+  // }
+  auto operand_array = MakeUnique<Array2D<float>>(2, 4);
+  operand_array->FillUnique(1.0f);
+  auto operand_literal = Literal::CreateR2FromArray2D<float>(*operand_array);
+
+  HloInstruction* operand = b.AddInstruction(
+      HloInstruction::CreateConstant(std::move(operand_literal)));
+
+  auto start_indices = b.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR1<int32>({0, 1})));
+
+  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
+  b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand,
+                                                      start_indices, {2, 3}));
+
+  std::unique_ptr<Literal> result =
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
+
+  auto expected = Literal::CreateR2<float>({
+      {2, 3, 4},
+      {6, 7, 8},
+  });
+
+  LiteralTestUtil::ExpectEqual(*expected, *result);
+}
+
+TEST_F(HloEvaluatorTest, DynamicSliceUpdate) {
+  HloComputation::Builder b(TestName());
+
+  // arg:
+  // f32[2,3] {
+  //  { 1, 2, 3 },
+  //  { 5, 6, 7 },
+  // }
+  auto operand_array = MakeUnique<Array2D<double>>(2, 3);
+  operand_array->FillUnique(1.0);
+  auto operand_literal = Literal::CreateR2FromArray2D<double>(*operand_array);
+
+  HloInstruction* operand = b.AddInstruction(
+      HloInstruction::CreateConstant(std::move(operand_literal)));
+
+  auto start_indices = b.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR1<int64>({0, 1})));
+
+  auto update = b.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));
+
+  Shape shape = ShapeUtil::MakeShape(F64, {2, 3});
+  b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+      shape, operand, update, start_indices));
+
+  std::unique_ptr<Literal> result =
+      evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
+
+  auto expected = Literal::CreateR2<double>({
+      {1, -2, -3},
+      {5, -6, -7},
+  });
+
+  LiteralTestUtil::ExpectEqual(*expected, *result);
+}
+
 }  // namespace
 }  // namespace xla
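The slice helpers in this commit reduce to simple index arithmetic via std::transform; the following standalone sketch (plain C++, no XLA dependencies, reusing the 2x4 input and {0, 1} start indices of the DynamicSlice test above) shows the same offsetting:

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

int main() {
  // Input f32[2,4] as in the DynamicSlice test: {{1,2,3,4},{5,6,7,8}}.
  const std::vector<std::vector<float>> operand = {{1, 2, 3, 4},
                                                   {5, 6, 7, 8}};
  const std::vector<int64_t> start = {0, 1};  // start_indices
  const std::vector<int64_t> sizes = {2, 3};  // dynamic_slice_sizes

  // Mirrors the lambda in DynamicSlice<IndexT>: each output index is shifted
  // by the start indices to address the operand.
  for (int64_t i = 0; i < sizes[0]; ++i) {
    for (int64_t j = 0; j < sizes[1]; ++j) {
      std::vector<int64_t> multi_index = {i, j};
      std::vector<int64_t> operand_index(2);
      std::transform(multi_index.begin(), multi_index.end(), start.begin(),
                     operand_index.begin(), std::plus<int64_t>());
      std::cout << operand[operand_index[0]][operand_index[1]] << " ";
    }
    std::cout << "\n";  // prints: 2 3 4 / 6 7 8, matching the test's expected.
  }
}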