diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 55f5504de42..6e8df608502 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -194,14 +194,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleCopy(HloInstruction* copy) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[copy], - ElementWiseUnaryOp(copy, [](ReturnT elem_operand) { - return elem_operand; - })); - return Status::OK(); - }; - Status HandleConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); @@ -402,6 +394,36 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; + Status HandleReverse(HloInstruction* reverse, + HloInstruction* operand) override { + const auto result_shape = reverse->shape(); + const auto reverse_dimensions = reverse->dimensions(); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferReverseShape(operand->shape(), + reverse_dimensions)); + + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(result_shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + auto operand_literal = parent_->GetEvaluatedLiteralFor(operand); + auto result = Literal::CreateFromShape(result_shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice out_index) { + std::vector from_index(out_index.begin(), out_index.end()); + for (const int64 dim : reverse_dimensions) { + from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; + } + return operand_literal.Get(from_index); + })); + + parent_->evaluated_[reverse] = std::move(result); + return Status::OK(); + }; + Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override { const Shape& result_shape = conv->shape(); @@ -1301,4 +1323,12 @@ Status HloEvaluator::HandleSlice(HloInstruction* slice, return Status::OK(); } +Status HloEvaluator::HandleCopy(HloInstruction* copy) { + TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape())); + + auto result = MakeUnique(GetEvaluatedLiteralFor(copy->operand(0))); + evaluated_[copy] = std::move(result); + return Status::OK(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index fbb385c40fa..920c4901daf 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -125,6 +125,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleCompare(HloInstruction* compare, HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) override; + Status HandleCopy(HloInstruction* copy) override; + private: // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index d2770f6a612..088b76b62f4 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -1057,5 +1057,58 @@ TEST_F(HloEvaluatorTest, DynamicSliceUpdate) { LiteralTestUtil::ExpectEqual(*expected, *result); } +TEST_F(HloEvaluatorTest, Reverse) { + HloComputation::Builder b(TestName()); + + // Input shape is float[4x3x2x1]. + // clang-format off + Array4D input({ + {{{1.0f}, {2.0f}}, + {{3.0f}, {4.0f}}, + {{5.0f}, {6.0f}}}, + {{{7.0f}, {8.0f}}, + {{9.0f}, {10.0f}}, + {{11.0f}, {12.0f}}}, + {{{13.0f}, {14.0f}}, + {{15.0f}, {16.0f}}, + {{17.0f}, {18.0f}}}, + {{{19.0f}, {20.0f}}, + {{21.0f}, {22.0f}}, + {{23.0f}, {24.0f}}}, + }); + // clang-format on + auto operand_literal = Literal::CreateR4FromArray4D(input); + HloInstruction* operand = b.AddInstruction( + HloInstruction::CreateConstant(std::move(operand_literal))); + + const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1}); + b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); + + std::unique_ptr result = + evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie(); + + // clang-format off + auto expected = Literal::CreateR4FromArray4D({ + {{{23.0f}, {24.0f}}, + {{21.0f}, {22.0f}}, + {{19.0f}, {20.0f}}}, + + {{{17.0f}, {18.0f}}, + {{15.0f}, {16.0f}}, + {{13.0f}, {14.0f}}}, + + {{{11.0f}, {12.0f}}, + {{9.0f}, {10.0f}}, + {{7.0f}, {8.0f}}}, + + {{{5.0f}, {6.0f}}, + {{3.0f}, {4.0f}}, + {{1.0f}, {2.0f}}}, + }); + // clang-format on + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + } // namespace } // namespace xla