[XLA] Handle Reverse in HloEvaluator.

Also move HandleCopy to outer visitor instead, since it can be implemented
as a type-agnostic copy instead.

PiperOrigin-RevId: 163866499
This commit is contained in:
Kay Zhu 2017-08-01 12:17:13 -07:00 committed by TensorFlower Gardener
parent 96675956ef
commit e62de3f784
3 changed files with 93 additions and 8 deletions

View File

@ -194,14 +194,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK(); return Status::OK();
}; };
Status HandleCopy(HloInstruction* copy) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[copy],
ElementWiseUnaryOp(copy, [](ReturnT elem_operand) {
return elem_operand;
}));
return Status::OK();
};
Status HandleConvert(HloInstruction* convert) override { Status HandleConvert(HloInstruction* convert) override {
const HloInstruction* operand = convert->operand(0); const HloInstruction* operand = convert->operand(0);
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
@ -402,6 +394,36 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK(); return Status::OK();
}; };
Status HandleReverse(HloInstruction* reverse,
HloInstruction* operand) override {
const auto result_shape = reverse->shape();
const auto reverse_dimensions = reverse->dimensions();
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferReverseShape(operand->shape(),
reverse_dimensions));
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape set to: " << ShapeUtil::HumanString(result_shape)
<< " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);
auto result = Literal::CreateFromShape(result_shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> out_index) {
std::vector<int64> from_index(out_index.begin(), out_index.end());
for (const int64 dim : reverse_dimensions) {
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
}
return operand_literal.Get<ReturnT>(from_index);
}));
parent_->evaluated_[reverse] = std::move(result);
return Status::OK();
};
Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs, Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs,
HloInstruction* rhs, const Window& window) override { HloInstruction* rhs, const Window& window) override {
const Shape& result_shape = conv->shape(); const Shape& result_shape = conv->shape();
@ -1301,4 +1323,12 @@ Status HloEvaluator::HandleSlice(HloInstruction* slice,
return Status::OK(); return Status::OK();
} }
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
auto result = MakeUnique<Literal>(GetEvaluatedLiteralFor(copy->operand(0)));
evaluated_[copy] = std::move(result);
return Status::OK();
}
} // namespace xla } // namespace xla

View File

@ -125,6 +125,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleCompare(HloInstruction* compare, HloOpcode opcode, Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
HloInstruction* lhs, HloInstruction* rhs) override; HloInstruction* lhs, HloInstruction* rhs) override;
Status HandleCopy(HloInstruction* copy) override;
private: private:
// Returns the already-evaluated literal result for the instruction. // Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be // A Constant instruction is considered evaluated and its literal will be

View File

@ -1057,5 +1057,58 @@ TEST_F(HloEvaluatorTest, DynamicSliceUpdate) {
LiteralTestUtil::ExpectEqual(*expected, *result); LiteralTestUtil::ExpectEqual(*expected, *result);
} }
TEST_F(HloEvaluatorTest, Reverse) {
HloComputation::Builder b(TestName());
// Input shape is float[4x3x2x1].
// clang-format off
Array4D<float> input({
{{{1.0f}, {2.0f}},
{{3.0f}, {4.0f}},
{{5.0f}, {6.0f}}},
{{{7.0f}, {8.0f}},
{{9.0f}, {10.0f}},
{{11.0f}, {12.0f}}},
{{{13.0f}, {14.0f}},
{{15.0f}, {16.0f}},
{{17.0f}, {18.0f}}},
{{{19.0f}, {20.0f}},
{{21.0f}, {22.0f}},
{{23.0f}, {24.0f}}},
});
// clang-format on
auto operand_literal = Literal::CreateR4FromArray4D<float>(input);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1});
b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1}));
std::unique_ptr<Literal> result =
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
// clang-format off
auto expected = Literal::CreateR4FromArray4D<float>({
{{{23.0f}, {24.0f}},
{{21.0f}, {22.0f}},
{{19.0f}, {20.0f}}},
{{{17.0f}, {18.0f}},
{{15.0f}, {16.0f}},
{{13.0f}, {14.0f}}},
{{{11.0f}, {12.0f}},
{{9.0f}, {10.0f}},
{{7.0f}, {8.0f}}},
{{{5.0f}, {6.0f}},
{{3.0f}, {4.0f}},
{{1.0f}, {2.0f}}},
});
// clang-format on
LiteralTestUtil::ExpectEqual(*expected, *result);
}
} // namespace } // namespace
} // namespace xla } // namespace xla