mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[XLA] Handle Reverse in HloEvaluator.
Also move HandleCopy to outer visitor instead, since it can be implemented as a type-agnostic copy instead. PiperOrigin-RevId: 163866499
This commit is contained in:
parent
96675956ef
commit
e62de3f784
|
|
@ -194,14 +194,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||
return Status::OK();
|
||||
};
|
||||
|
||||
Status HandleCopy(HloInstruction* copy) override {
|
||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[copy],
|
||||
ElementWiseUnaryOp(copy, [](ReturnT elem_operand) {
|
||||
return elem_operand;
|
||||
}));
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
Status HandleConvert(HloInstruction* convert) override {
|
||||
const HloInstruction* operand = convert->operand(0);
|
||||
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
|
||||
|
|
@ -402,6 +394,36 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
|||
return Status::OK();
|
||||
};
|
||||
|
||||
Status HandleReverse(HloInstruction* reverse,
|
||||
HloInstruction* operand) override {
|
||||
const auto result_shape = reverse->shape();
|
||||
const auto reverse_dimensions = reverse->dimensions();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
|
||||
ShapeInference::InferReverseShape(operand->shape(),
|
||||
reverse_dimensions));
|
||||
|
||||
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
|
||||
<< "return shape set to: " << ShapeUtil::HumanString(result_shape)
|
||||
<< " but is inferred to be: "
|
||||
<< ShapeUtil::HumanString(inferred_return_shape);
|
||||
|
||||
auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);
|
||||
auto result = Literal::CreateFromShape(result_shape);
|
||||
|
||||
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
|
||||
[&](tensorflow::gtl::ArraySlice<int64> out_index) {
|
||||
std::vector<int64> from_index(out_index.begin(), out_index.end());
|
||||
for (const int64 dim : reverse_dimensions) {
|
||||
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
|
||||
}
|
||||
return operand_literal.Get<ReturnT>(from_index);
|
||||
}));
|
||||
|
||||
parent_->evaluated_[reverse] = std::move(result);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs,
|
||||
HloInstruction* rhs, const Window& window) override {
|
||||
const Shape& result_shape = conv->shape();
|
||||
|
|
@ -1301,4 +1323,12 @@ Status HloEvaluator::HandleSlice(HloInstruction* slice,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
|
||||
TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
|
||||
|
||||
auto result = MakeUnique<Literal>(GetEvaluatedLiteralFor(copy->operand(0)));
|
||||
evaluated_[copy] = std::move(result);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
|
|
|||
|
|
@ -125,6 +125,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
|
|||
Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
|
||||
HloInstruction* lhs, HloInstruction* rhs) override;
|
||||
|
||||
Status HandleCopy(HloInstruction* copy) override;
|
||||
|
||||
private:
|
||||
// Returns the already-evaluated literal result for the instruction.
|
||||
// A Constant instruction is considered evaluated and its literal will be
|
||||
|
|
|
|||
|
|
@ -1057,5 +1057,58 @@ TEST_F(HloEvaluatorTest, DynamicSliceUpdate) {
|
|||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
||||
}
|
||||
|
||||
TEST_F(HloEvaluatorTest, Reverse) {
|
||||
HloComputation::Builder b(TestName());
|
||||
|
||||
// Input shape is float[4x3x2x1].
|
||||
// clang-format off
|
||||
Array4D<float> input({
|
||||
{{{1.0f}, {2.0f}},
|
||||
{{3.0f}, {4.0f}},
|
||||
{{5.0f}, {6.0f}}},
|
||||
{{{7.0f}, {8.0f}},
|
||||
{{9.0f}, {10.0f}},
|
||||
{{11.0f}, {12.0f}}},
|
||||
{{{13.0f}, {14.0f}},
|
||||
{{15.0f}, {16.0f}},
|
||||
{{17.0f}, {18.0f}}},
|
||||
{{{19.0f}, {20.0f}},
|
||||
{{21.0f}, {22.0f}},
|
||||
{{23.0f}, {24.0f}}},
|
||||
});
|
||||
// clang-format on
|
||||
auto operand_literal = Literal::CreateR4FromArray4D<float>(input);
|
||||
HloInstruction* operand = b.AddInstruction(
|
||||
HloInstruction::CreateConstant(std::move(operand_literal)));
|
||||
|
||||
const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1});
|
||||
b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1}));
|
||||
|
||||
std::unique_ptr<Literal> result =
|
||||
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
|
||||
|
||||
// clang-format off
|
||||
auto expected = Literal::CreateR4FromArray4D<float>({
|
||||
{{{23.0f}, {24.0f}},
|
||||
{{21.0f}, {22.0f}},
|
||||
{{19.0f}, {20.0f}}},
|
||||
|
||||
{{{17.0f}, {18.0f}},
|
||||
{{15.0f}, {16.0f}},
|
||||
{{13.0f}, {14.0f}}},
|
||||
|
||||
{{{11.0f}, {12.0f}},
|
||||
{{9.0f}, {10.0f}},
|
||||
{{7.0f}, {8.0f}}},
|
||||
|
||||
{{{5.0f}, {6.0f}},
|
||||
{{3.0f}, {4.0f}},
|
||||
{{1.0f}, {2.0f}}},
|
||||
});
|
||||
// clang-format on
|
||||
|
||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user