mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
[XLA] Handle Reverse in HloEvaluator.
Also move HandleCopy to outer visitor instead, since it can be implemented as a type-agnostic copy instead. PiperOrigin-RevId: 163866499
This commit is contained in:
parent
96675956ef
commit
e62de3f784
|
|
@ -194,14 +194,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
};
|
||||||
|
|
||||||
Status HandleCopy(HloInstruction* copy) override {
|
|
||||||
TF_ASSIGN_OR_RETURN(parent_->evaluated_[copy],
|
|
||||||
ElementWiseUnaryOp(copy, [](ReturnT elem_operand) {
|
|
||||||
return elem_operand;
|
|
||||||
}));
|
|
||||||
return Status::OK();
|
|
||||||
};
|
|
||||||
|
|
||||||
Status HandleConvert(HloInstruction* convert) override {
|
Status HandleConvert(HloInstruction* convert) override {
|
||||||
const HloInstruction* operand = convert->operand(0);
|
const HloInstruction* operand = convert->operand(0);
|
||||||
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
|
TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
|
||||||
|
|
@ -402,6 +394,36 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Status HandleReverse(HloInstruction* reverse,
|
||||||
|
HloInstruction* operand) override {
|
||||||
|
const auto result_shape = reverse->shape();
|
||||||
|
const auto reverse_dimensions = reverse->dimensions();
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
|
||||||
|
ShapeInference::InferReverseShape(operand->shape(),
|
||||||
|
reverse_dimensions));
|
||||||
|
|
||||||
|
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
|
||||||
|
<< "return shape set to: " << ShapeUtil::HumanString(result_shape)
|
||||||
|
<< " but is inferred to be: "
|
||||||
|
<< ShapeUtil::HumanString(inferred_return_shape);
|
||||||
|
|
||||||
|
auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);
|
||||||
|
auto result = Literal::CreateFromShape(result_shape);
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
|
||||||
|
[&](tensorflow::gtl::ArraySlice<int64> out_index) {
|
||||||
|
std::vector<int64> from_index(out_index.begin(), out_index.end());
|
||||||
|
for (const int64 dim : reverse_dimensions) {
|
||||||
|
from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
|
||||||
|
}
|
||||||
|
return operand_literal.Get<ReturnT>(from_index);
|
||||||
|
}));
|
||||||
|
|
||||||
|
parent_->evaluated_[reverse] = std::move(result);
|
||||||
|
return Status::OK();
|
||||||
|
};
|
||||||
|
|
||||||
Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs,
|
Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs,
|
||||||
HloInstruction* rhs, const Window& window) override {
|
HloInstruction* rhs, const Window& window) override {
|
||||||
const Shape& result_shape = conv->shape();
|
const Shape& result_shape = conv->shape();
|
||||||
|
|
@ -1301,4 +1323,12 @@ Status HloEvaluator::HandleSlice(HloInstruction* slice,
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
|
||||||
|
TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
|
||||||
|
|
||||||
|
auto result = MakeUnique<Literal>(GetEvaluatedLiteralFor(copy->operand(0)));
|
||||||
|
evaluated_[copy] = std::move(result);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
|
|
@ -125,6 +125,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
|
||||||
Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
|
Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
|
||||||
HloInstruction* lhs, HloInstruction* rhs) override;
|
HloInstruction* lhs, HloInstruction* rhs) override;
|
||||||
|
|
||||||
|
Status HandleCopy(HloInstruction* copy) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Returns the already-evaluated literal result for the instruction.
|
// Returns the already-evaluated literal result for the instruction.
|
||||||
// A Constant instruction is considered evaluated and its literal will be
|
// A Constant instruction is considered evaluated and its literal will be
|
||||||
|
|
|
||||||
|
|
@ -1057,5 +1057,58 @@ TEST_F(HloEvaluatorTest, DynamicSliceUpdate) {
|
||||||
LiteralTestUtil::ExpectEqual(*expected, *result);
|
LiteralTestUtil::ExpectEqual(*expected, *result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HloEvaluatorTest, Reverse) {
|
||||||
|
HloComputation::Builder b(TestName());
|
||||||
|
|
||||||
|
// Input shape is float[4x3x2x1].
|
||||||
|
// clang-format off
|
||||||
|
Array4D<float> input({
|
||||||
|
{{{1.0f}, {2.0f}},
|
||||||
|
{{3.0f}, {4.0f}},
|
||||||
|
{{5.0f}, {6.0f}}},
|
||||||
|
{{{7.0f}, {8.0f}},
|
||||||
|
{{9.0f}, {10.0f}},
|
||||||
|
{{11.0f}, {12.0f}}},
|
||||||
|
{{{13.0f}, {14.0f}},
|
||||||
|
{{15.0f}, {16.0f}},
|
||||||
|
{{17.0f}, {18.0f}}},
|
||||||
|
{{{19.0f}, {20.0f}},
|
||||||
|
{{21.0f}, {22.0f}},
|
||||||
|
{{23.0f}, {24.0f}}},
|
||||||
|
});
|
||||||
|
// clang-format on
|
||||||
|
auto operand_literal = Literal::CreateR4FromArray4D<float>(input);
|
||||||
|
HloInstruction* operand = b.AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(std::move(operand_literal)));
|
||||||
|
|
||||||
|
const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1});
|
||||||
|
b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1}));
|
||||||
|
|
||||||
|
std::unique_ptr<Literal> result =
|
||||||
|
evaluator_->Evaluate(*b.Build(), {}).ConsumeValueOrDie();
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
auto expected = Literal::CreateR4FromArray4D<float>({
|
||||||
|
{{{23.0f}, {24.0f}},
|
||||||
|
{{21.0f}, {22.0f}},
|
||||||
|
{{19.0f}, {20.0f}}},
|
||||||
|
|
||||||
|
{{{17.0f}, {18.0f}},
|
||||||
|
{{15.0f}, {16.0f}},
|
||||||
|
{{13.0f}, {14.0f}}},
|
||||||
|
|
||||||
|
{{{11.0f}, {12.0f}},
|
||||||
|
{{9.0f}, {10.0f}},
|
||||||
|
{{7.0f}, {8.0f}}},
|
||||||
|
|
||||||
|
{{{5.0f}, {6.0f}},
|
||||||
|
{{3.0f}, {4.0f}},
|
||||||
|
{{1.0f}, {2.0f}}},
|
||||||
|
});
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
LiteralTestUtil::ExpectEqual(*expected, *result);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user