[XLA] Add Reduce, DynamicSlice and DynamicUpdateSlice to HloEvaluator.

- Reduce is explicitly disabled for constant folding, as not all types of
embedded computation are currently supported by the evaluator.

- Added support for evaluating an HloModule to HloEvaluator.

- Minor signature change to Evaluate(): computations are now taken by const
reference (see the usage sketch below).
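For reference, the new call sites look roughly as follows. This is a hedged
sketch, not code from this change: `module`, `computation`, and the argument
literals are assumed to be built elsewhere, inside a function that returns a
Status.

    HloEvaluator evaluator;
    std::vector<const Literal*> args = {arg0.get(), arg1.get()};

    // New: evaluate a whole HloModule; the evaluator visits the module's
    // entry computation.
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> module_result,
                        evaluator.Evaluate(*module, args));

    // Changed: a computation is now passed by const reference rather than
    // by pointer.
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> comp_result,
                        evaluator.Evaluate(*computation, args));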

PiperOrigin-RevId: 163299238
Kay Zhu authored 2017-07-26 21:55:47 -07:00; committed by TensorFlower Gardener
parent a524701723
commit 631a364cd1
5 changed files with 442 additions and 54 deletions

tensorflow/compiler/plugin/executor/executable.cc

@@ -30,8 +30,8 @@ ExecutorExecutable::ExecutorExecutable(std::unique_ptr<HloModule> hlo_module)
ExecutorExecutable::~ExecutorExecutable() {}
static se::DeviceMemoryBase AllocateSingleOutput(sep::ExecutorExecutor* executor,
const Literal& literal) {
static se::DeviceMemoryBase AllocateSingleOutput(
sep::ExecutorExecutor* executor, const Literal& literal) {
int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape()));
void* buf = executor->Allocate(size);
const void* src = literal.InternalData();
@@ -39,8 +39,8 @@ static se::DeviceMemoryBase AllocateSingleOutput(sep::ExecutorExecutor* executor
return se::DeviceMemoryBase(buf, size);
}
static se::DeviceMemoryBase AllocateOutputBuffer(sep::ExecutorExecutor* executor,
const Literal& literal) {
static se::DeviceMemoryBase AllocateOutputBuffer(
sep::ExecutorExecutor* executor, const Literal& literal) {
const Shape& shape = literal.shape();
if (shape.element_type() != xla::TUPLE) {
return AllocateSingleOutput(executor, literal);
@@ -96,7 +96,7 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream(
// Execute the graph using the evaluator
HloEvaluator evaluator;
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> output,
evaluator.Evaluate(computation, arg_literals_ptrs));
evaluator.Evaluate(*computation, arg_literals_ptrs));
// Copy the result into the return buffer
perftools::gputools::StreamExecutor* executor(stream->parent());
@@ -139,6 +139,5 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteAsyncOnStream(
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
}
} // namespace executorplugin
} // namespace xla

tensorflow/compiler/xla/service/hlo_constant_folding.cc

@@ -51,9 +51,12 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
computation->root_instruction() != instruction) {
continue;
}
// Skip Constant and Parameter operation.
// Skip Constant, Parameter, and Reduce operations.
// TODO(b/35975797): Enable Reduce operation once arbitrary computations are
// supported by the evaluator.
if (instruction->opcode() == HloOpcode::kParameter ||
instruction->opcode() == HloOpcode::kConstant) {
instruction->opcode() == HloOpcode::kConstant ||
instruction->opcode() == HloOpcode::kReduce) {
continue;
}
// Skip instructions with non-constant operands.
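For context, the skip logic this hunk adds amounts to the following predicate.
A hedged sketch only: the helper name is illustrative and does not exist in
the pass.

    // Returns true if constant folding should leave this instruction alone.
    // kReduce is skipped until the evaluator can handle arbitrary embedded
    // computations (b/35975797).
    bool ShouldSkipConstantFolding(const HloInstruction& instruction) {
      switch (instruction.opcode()) {
        case HloOpcode::kParameter:  // nothing to fold
        case HloOpcode::kConstant:   // already a constant
        case HloOpcode::kReduce:     // embedded computation not yet supported
          return true;
        default:
          return false;
      }
    }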

tensorflow/compiler/xla/service/hlo_evaluator.cc

@@ -654,12 +654,262 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
Status HandleDynamicSlice(HloInstruction* dynamic_slice,
HloInstruction* operand,
HloInstruction* start_indices) override {
auto result_shape = dynamic_slice->shape();
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferDynamicSliceShape(
operand->shape(), start_indices->shape(),
dynamic_slice->dynamic_slice_sizes()));
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(result_shape)
<< "but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
TF_RET_CHECK(
primitive_util::IsIntegralType(start_indices->shape().element_type()));
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
const Literal& start_indices_literal =
parent_->GetEvaluatedLiteralFor(start_indices);
switch (start_indices->shape().element_type()) {
case S32: {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[dynamic_slice],
DynamicSlice<int32>(operand_literal, start_indices_literal,
result_shape));
} break;
case S64: {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[dynamic_slice],
DynamicSlice<int64>(operand_literal, start_indices_literal,
result_shape));
} break;
case U32: {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[dynamic_slice],
DynamicSlice<uint32>(operand_literal, start_indices_literal,
result_shape));
} break;
case U64: {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[dynamic_slice],
DynamicSlice<uint64>(operand_literal, start_indices_literal,
result_shape));
} break;
default:
LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for "
"start_indices: "
<< PrimitiveType_Name(start_indices->shape().element_type());
}
return Status::OK();
};
Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
HloInstruction* operand,
HloInstruction* update,
HloInstruction* start_indices) override {
auto result_shape = dynamic_update_slice->shape();
TF_ASSIGN_OR_RETURN(
auto inferred_return_shape,
ShapeInference::InferDynamicUpdateSliceShape(
operand->shape(), update->shape(), start_indices->shape()));
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(result_shape)
<< "but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
TF_RET_CHECK(
primitive_util::IsIntegralType(start_indices->shape().element_type()));
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape()));
const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update);
const Literal& start_indices_literal =
parent_->GetEvaluatedLiteralFor(start_indices);
switch (start_indices->shape().element_type()) {
case S32: {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[dynamic_update_slice],
DynamicUpdateSlice<int32>(operand_literal, update_literal,
start_indices_literal));
} break;
case S64: {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[dynamic_update_slice],
DynamicUpdateSlice<int64>(operand_literal, update_literal,
start_indices_literal));
} break;
case U32: {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[dynamic_update_slice],
DynamicUpdateSlice<uint32>(operand_literal, update_literal,
start_indices_literal));
} break;
case U64: {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[dynamic_update_slice],
DynamicUpdateSlice<uint64>(operand_literal, update_literal,
start_indices_literal));
} break;
default:
LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
"start_indices: "
<< PrimitiveType_Name(start_indices->shape().element_type());
}
return Status::OK();
};
Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions,
HloComputation* function) override {
TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) ==
ShapeUtil::Rank(arg->shape()) - dimensions.size());
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferReduceShape(
/*arg=*/arg->shape(),
/*init_value=*/init_value->shape(),
/*dimensions_to_reduce=*/dimensions,
/*to_apply=*/function->ComputeProgramShape()));
TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
<< "but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString();
const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value);
VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString();
TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
auto init_scalar = init_literal.Get<ReturnT>({});
auto result = Literal::CreateFromShape(reduce->shape());
const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions());
std::vector<int64> arg_dim_steps(arg_dimensions.size());
std::vector<int64> arg_dim_counts(arg_dimensions.size());
for (const int64 dim : dimensions) {
arg_dim_steps[dim] = 1;
arg_dim_counts[dim] = arg_dimensions[dim];
}
// Create mapping from result index to arg index.
const int64 result_rank = ShapeUtil::Rank(result->shape());
int64 result_dim = 0;
std::vector<int64> result_to_arg_index(result_rank);
for (int64 i = 0; i < arg_dimensions.size(); ++i) {
if (arg_dim_steps[i] == 0) {
result_to_arg_index[result_dim] = i;
++result_dim;
}
}
// For each resulting dimension, calculate and assign computed value.
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
ReturnT result_val = init_scalar;
std::vector<int64> base(arg_dimensions.size());
for (int64 i = 0; i < multi_index.size(); ++i) {
base[result_to_arg_index[i]] = multi_index[i];
}
auto func = [&](const std::vector<int64>& input_index) {
auto curr_val = arg_literal.Get<ReturnT>(input_index);
// Evaluate computation with specified literal operands.
auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
std::vector<const Literal*> args = {curr_val_literal.get(),
result_val_literal.get()};
// We need a new visitor for each evaluation, so that the same
// computation can be visited more than once (with different
// inputs).
HloEvaluator embedded_evaluator;
std::unique_ptr<Literal> computed_result =
embedded_evaluator.Evaluate(*function, args)
.ConsumeValueOrDie();
// Assign computed result to result_val.
result_val = computed_result->Get<ReturnT>({});
return true;
};
ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts,
arg_dim_steps, func);
return result_val;
}));
parent_->evaluated_[reduce] = std::move(result);
return Status::OK();
};
Status Preprocess(HloInstruction* hlo) override {
VLOG(2) << hlo->ToString();
return Status::OK();
};
private:
template <typename IndexT>
StatusOr<std::unique_ptr<Literal>> DynamicSlice(
const Literal& operand_literal, const Literal& start_indices_literal,
const Shape& result_shape) {
const auto& start_indices_typed =
start_indices_literal.GetArraySlice<IndexT>();
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
std::vector<int64> operand_indices(start.size(), 0);
auto result = Literal::CreateFromShape(result_shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> multi_index) {
std::transform(multi_index.begin(), multi_index.end(), start.begin(),
operand_indices.begin(), std::plus<int64>());
return operand_literal.Get<ReturnT>(operand_indices);
}));
return std::move(result);
}
template <typename IndexT>
StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
const Literal& operand_literal, const Literal& update_literal,
const Literal& start_indices_literal) {
const auto& start_indices_typed =
start_indices_literal.GetArraySlice<IndexT>();
const std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
auto result = MakeUnique<Literal>(operand_literal);
std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
auto func = [&](const std::vector<int64>& update_index) {
std::transform(update_index.begin(), update_index.end(), start.begin(),
result_index.begin(), std::plus<int64>());
result->Set<ReturnT>(result_index,
update_literal.Get<ReturnT>(update_index));
return true;
};
std::vector<int64> base(update_literal.shape().dimensions_size(), 0);
std::vector<int64> step(update_literal.shape().dimensions_size(), 1);
ShapeUtil::ForEachIndex(update_literal.shape(), base,
AsInt64Slice(update_literal.shape().dimensions()),
step, func);
return std::move(result);
}
StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
HloInstruction* instruction,
const std::function<ReturnT(ReturnT)>& unary_op) {
@@ -771,14 +1021,28 @@ HloEvaluator::HloEvaluator() {
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
HloComputation* computation,
tensorflow::gtl::ArraySlice<const Literal*> args) {
arg_literals_ = args;
const HloModule& module,
tensorflow::gtl::ArraySlice<const Literal*> arg_literals) {
XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
arg_literals_ = arg_literals;
evaluated_.clear();
TF_RETURN_IF_ERROR(computation->Accept(this));
TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
return MakeUnique<Literal>(
GetEvaluatedLiteralFor(computation->root_instruction()));
GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()));
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
const HloComputation& computation,
tensorflow::gtl::ArraySlice<const Literal*> arg_literals) {
arg_literals_ = arg_literals;
evaluated_.clear();
TF_RETURN_IF_ERROR(computation.Accept(this));
return MakeUnique<Literal>(
GetEvaluatedLiteralFor(computation.root_instruction()));
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
@@ -930,7 +1194,8 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite,
break;
}
default:
LOG(FATAL) << "unknown/unhandled primitive type.";
LOG(FATAL) << "HandleIsFinite: unknown/unhandled primitive type: "
<< PrimitiveType_Name(operand->shape().element_type());
}
return Status::OK();
@@ -1009,7 +1274,8 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode,
Compare<double>(compare->shape(), opcode, lhs_literal, rhs_literal));
} break;
default:
LOG(FATAL) << "unknown primitive type.";
LOG(FATAL) << "HandleCompare: unknown primitive type: "
<< PrimitiveType_Name(lhs->shape().element_type());
}
return Status::OK();
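
To make the index bookkeeping in HandleReduce above concrete, here is a
minimal standalone sketch (not part of the commit; plain C++ with int64_t
standing in for XLA's int64) for a [2,3] argument reduced over dimension 1:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      const std::vector<int64_t> arg_dims = {2, 3};     // reduce arg shape
      const std::vector<int64_t> dims_to_reduce = {1};  // sum over columns

      // Mirrors arg_dim_steps/arg_dim_counts: reduced dimensions get step 1
      // and their full extent; kept dimensions stay at 0.
      std::vector<int64_t> steps(arg_dims.size(), 0);
      std::vector<int64_t> counts(arg_dims.size(), 0);
      for (int64_t dim : dims_to_reduce) {
        steps[dim] = 1;
        counts[dim] = arg_dims[dim];
      }

      // Dimensions whose step stayed 0 survive into the result, in order;
      // this is result_to_arg_index.
      std::vector<int64_t> result_to_arg_index;
      for (size_t i = 0; i < arg_dims.size(); ++i) {
        if (steps[i] == 0) result_to_arg_index.push_back(i);
      }

      // Result index {i} starts from arg base {i, 0}; ForEachIndex then
      // walks {i, 0}, {i, 1}, {i, 2}, accumulating one row per result cell.
      for (int64_t i = 0; i < arg_dims[result_to_arg_index[0]]; ++i) {
        for (int64_t j = 0; j < counts[1]; ++j) {
          std::printf("result {%lld} accumulates arg {%lld,%lld}\n",
                      static_cast<long long>(i), static_cast<long long>(i),
                      static_cast<long long>(j));
        }
      }
      return 0;
    }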

tensorflow/compiler/xla/service/hlo_evaluator.h

@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -36,9 +37,17 @@ namespace xla {
class HloEvaluator : public DfsHloVisitorWithDefault {
public:
HloEvaluator();
// Evaluates a HLO computation and an array of pointers to literals.
// Return the evaluated result as literal if successful.
// Precondition: argument literals are corresponds to the input computation's
// Evaluates an HLO module and an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
// Precondition: argument literals correspond to the module's entry computation's
// parameters in their post-ordering. See comment below for example.
StatusOr<std::unique_ptr<Literal>> Evaluate(
const HloModule& module,
tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
// Evaluates an HLO computation and an array of pointers to literals.
// Returns the evaluated result as a literal if successful.
// Precondition: argument literals correspond to the input computation's
// parameters in their post-ordering. For example, consider the following graph:
//
// *
@@ -51,7 +60,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// The input literals array will have its first literal map to Parameter0 and
// the second map to Parameter1.
StatusOr<std::unique_ptr<Literal>> Evaluate(
HloComputation* computation,
const HloComputation& computation,
tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
// Evaluates a single HLO instruction and an array of pointers to literals.
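
Tying the comment above to a call site: for the two-parameter graph sketched
there, the literals bind positionally. A hedged usage sketch with illustrative
values, assuming `computation` is a std::unique_ptr<HloComputation> built
elsewhere:

    auto lhs = Literal::CreateR0<float>(1.f);
    auto rhs = Literal::CreateR0<float>(2.f);
    HloEvaluator evaluator;
    // args[0] binds to Parameter0 and args[1] to Parameter1, following the
    // computation's post-ordering of parameters.
    std::unique_ptr<Literal> result =
        evaluator.Evaluate(*computation, {lhs.get(), rhs.get()})
            .ConsumeValueOrDie();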

tensorflow/compiler/xla/service/hlo_evaluator_test.cc

@@ -188,7 +188,7 @@ TEST_F(HloEvaluatorTest, DoesAbs) {
// Verifies that HloEvaluator evaluates an HLO computation with operands that
// are neither parameters nor constants.
TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
HloComputation::Builder builder(TestName());
HloComputation::Builder b(TestName());
auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
auto rhs2 = Literal::CreateR2<int64>({{1, -20}, {-100, 4}});
@@ -205,9 +205,9 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
auto root_instruction = HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, lhs_instruction.get(), param_rhs2.get());
builder.AddInstruction(std::move(root_instruction));
b.AddInstruction(std::move(root_instruction));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(builder.Build().get(), args).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), args).ConsumeValueOrDie();
auto expected = Literal::CreateR2<int64>({{4, -16}, {-196, 12}});
@@ -216,22 +216,22 @@
// Verifies Reshape operation is correctly evaluated.
TEST_F(HloEvaluatorTest, DoesReshape) {
HloComputation::Builder builder(TestName());
HloComputation::Builder b(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
LiteralTestUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = literal->CloneToUnique();
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
HloInstruction* literal_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
const int64 permutation[] = {1, 2, 0, 4, 3};
builder.AddInstruction(
b.AddInstruction(
HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
result->EachCell<NativeT>(
@@ -243,24 +243,24 @@ TEST_F(HloEvaluatorTest, DoesReshape) {
// Verifies Broadcast operation is correctly evaluated.
TEST_F(HloEvaluatorTest, DoesBroadcast) {
HloComputation::Builder builder(TestName());
HloComputation::Builder b(TestName());
auto input_literal = Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
auto output_literal = Literal::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}});
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction* literal_instruction = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
builder.AddInstruction(HloInstruction::CreateBroadcast(
b.AddInstruction(HloInstruction::CreateBroadcast(
output_literal->shape(), literal_instruction, {1, 2}));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
LiteralTestUtil::ExpectEqual(*result, *output_literal);
}
TEST_F(HloEvaluatorTest, ConvertWithSameLayout) {
HloComputation::Builder builder(TestName());
HloComputation::Builder b(TestName());
auto input_literal = Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
auto expected =
@@ -268,19 +268,18 @@ TEST_F(HloEvaluatorTest, ConvertWithSameLayout) {
ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
expected->shape()));
HloInstruction* constant = builder.AddInstruction(
HloInstruction* constant = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
builder.AddInstruction(
HloInstruction::CreateConvert(expected->shape(), constant));
b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
LiteralTestUtil::ExpectEqual(*result, *expected);
}
TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) {
HloComputation::Builder builder(TestName());
HloComputation::Builder b(TestName());
auto input_literal = Literal::CreateR2WithLayout<int32>(
{{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
@@ -289,13 +288,12 @@ TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) {
ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
expected->shape()));
HloInstruction* constant = builder.AddInstruction(
HloInstruction* constant = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
builder.AddInstruction(
HloInstruction::CreateConvert(expected->shape(), constant));
b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
LiteralTestUtil::ExpectEqual(*result, *expected);
}
@@ -355,7 +353,7 @@ TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
auto expected_array = MakeUnique<Array4D<float>>(8, 5, 1, 1);
expected_array->Fill(kPadValue);
@@ -398,7 +396,7 @@ TEST_F(HloEvaluatorTest, NegativePadding2D) {
r2_padding_on_dim0_dim1));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
// f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
auto expected_array = MakeUnique<Array2D<float>>(1, 5);
@@ -442,7 +440,7 @@ TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
r2_padding_on_dim0_dim1));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
auto expected_array = MakeUnique<Array2D<float>>(0, 9);
auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
@@ -477,7 +475,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) {
shape, HloOpcode::kDot, lhs_instruction, rhs_instruction));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
// clang-format off
auto expected_array = Array2D<float>({
@@ -519,7 +517,7 @@ TEST_F(HloEvaluatorTest, DotRank1AndRank2) {
shape, HloOpcode::kDot, lhs_instruction, rhs_instruction));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
auto expected = Literal::CreateR1<float>({22.f, 28.f});
@@ -559,10 +557,13 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) {
shape, HloOpcode::kDot, lhs_instruction, rhs_instruction));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
auto expected_array = Array2D<float>({
{22.f, 28.f}, {58.f, 76.f}, {94.f, 124.f}, {130.f, 172.f},
{22.f, 28.f},
{58.f, 76.f},
{94.f, 124.f},
{130.f, 172.f},
});
auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
@@ -606,7 +607,7 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) {
shape, lhs_instruction, rhs_instruction, window, dnums));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
auto expected = Literal::CreateR3FromArray3D<float>(expected_array);
@@ -660,7 +661,7 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
shape, lhs_instruction, rhs_instruction, window, dnums));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
Array4D<float> expected_array(1, 1, 4, 4);
// clang-format off
@@ -736,7 +737,7 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) {
shape, lhs_instruction, rhs_instruction, window, dnums));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
// clang-format off
// Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -793,7 +794,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
shape, lhs_instruction, rhs_instruction, window, dnums));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
Array4D<float> expected_array(1, 1, 7, 7);
expected_array.FillWithYX(Array2D<float>({
@@ -856,7 +857,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
shape, lhs_instruction, rhs_instruction, window, dnums));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
Array4D<float> expected_array(1, 1, 8, 8);
expected_array.FillWithYX(Array2D<float>({
@@ -927,7 +928,7 @@ TEST_F(HloEvaluatorTest,
shape, lhs_instruction, rhs_instruction, window, dnums));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(b.Build().get(), {}).ConsumeValueOrDie();
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
Array4D<float> expected_array(1, 1, 9, 3);
expected_array.FillWithYX(Array2D<float>({
@@ -946,5 +947,115 @@ TEST_F(HloEvaluatorTest,
LiteralTestUtil::ExpectEqual(*expected, *result);
}
TEST_F(HloEvaluatorTest, ReduceAdd) {
HloComputation::Builder b(TestName());
// arg:
// f32[2,3] {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
auto arg_array = MakeUnique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
auto arg_literal = Literal::CreateR2FromArray2D<float>(*arg_array);
HloInstruction* arg_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
auto init_value = b.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
HloComputation::Builder add_computation("add");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
auto param_lhs = add_computation.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
auto param_rhs = add_computation.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
add_computation.AddInstruction(HloInstruction::CreateBinary(
scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
auto add_func = add_computation.Build();
Shape shape = ShapeUtil::MakeShape(F32, {2});
b.AddInstruction(HloInstruction::CreateReduce(
shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{1},
add_func.get()));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
auto expected = Literal::CreateR1<float>({6, 18});
LiteralTestUtil::ExpectEqual(*expected, *result);
}
TEST_F(HloEvaluatorTest, DynamicSlice) {
HloComputation::Builder b(TestName());
// arg:
// f32[2,4] {
// { 1, 2, 3, 4 },
// { 5, 6, 7, 8 },
// }
auto operand_array = MakeUnique<Array2D<float>>(2, 4);
operand_array->FillUnique(1.0f);
auto operand_literal = Literal::CreateR2FromArray2D<float>(*operand_array);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
auto start_indices = b.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR1<int32>({0, 1})));
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand,
start_indices, {2, 3}));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
auto expected = Literal::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
LiteralTestUtil::ExpectEqual(*expected, *result);
}
TEST_F(HloEvaluatorTest, DynamicSliceUpdate) {
HloComputation::Builder b(TestName());
// arg:
// f64[2,3] {
// { 1, 2, 3 },
// { 5, 6, 7 },
// }
auto operand_array = MakeUnique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
auto operand_literal = Literal::CreateR2FromArray2D<double>(*operand_array);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
auto start_indices = b.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR1<int64>({0, 1})));
auto update = b.AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));
Shape shape = ShapeUtil::MakeShape(F64, {2, 3});
b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
shape, operand, update, start_indices));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
auto expected = Literal::CreateR2<double>({
{1, -2, -3},
{5, -6, -7},
});
LiteralTestUtil::ExpectEqual(*expected, *result);
}
} // namespace
} // namespace xla