[XLA] Add HLO matchers that check parameter numbers and GTE indices.

This lets you do

  EXPECT_THAT(foo, op::Parameter(42));

and

  EXPECT_THAT(bar, op::GetTupleElement(baz, 8));

PiperOrigin-RevId: 174113597
This commit is contained in:
Justin Lebar 2017-10-31 16:47:47 -07:00 committed by TensorFlower Gardener
parent f97e7c69b8
commit 4aa90bfd39
2 changed files with 96 additions and 2 deletions

View File

@ -73,6 +73,35 @@ void HloMatcher::DescribeTo(::std::ostream* os) const {
}
}
bool HloParameterMatcher::MatchAndExplain(
const HloInstruction* instruction,
::testing::MatchResultListener* listener) const {
if (!HloMatcher::MatchAndExplain(instruction, listener)) {
return false;
}
if (instruction->parameter_number() != parameter_number_) {
*listener << "has wrong parameter number (got "
<< instruction->parameter_number() << ", want "
<< parameter_number_ << ")";
return false;
}
return true;
}
bool HloGetTupleElementMatcher::MatchAndExplain(
const HloInstruction* instruction,
::testing::MatchResultListener* listener) const {
if (!HloMatcher::MatchAndExplain(instruction, listener)) {
return false;
}
if (instruction->tuple_index() != tuple_index_) {
*listener << "has wrong tuple index (got " << instruction->tuple_index()
<< ", want " << tuple_index_ << ")";
return false;
}
return true;
}
} // namespace testing
void PrintTo(const HloInstruction* inst, ::std::ostream* os) {

View File

@ -38,6 +38,36 @@ class HloMatcher : public ::testing::MatcherInterface<const HloInstruction*> {
std::vector<::testing::Matcher<const HloInstruction*>> operands_;
};
// Custom matcher for parameters, which accepts a parameter number.
class HloParameterMatcher : public HloMatcher {
public:
explicit HloParameterMatcher(int64 parameter_number)
: HloMatcher(HloOpcode::kParameter, /*operands=*/{}),
parameter_number_(parameter_number) {}
bool MatchAndExplain(const HloInstruction* instruction,
::testing::MatchResultListener* listener) const override;
private:
int64 parameter_number_;
};
// Custom matcher for get-tuple-element instructions, which accepts a tuple
// index to match.
class HloGetTupleElementMatcher : public HloMatcher {
public:
explicit HloGetTupleElementMatcher(
::testing::Matcher<const HloInstruction*> operand, int64 tuple_index)
: HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}),
tuple_index_(tuple_index) {}
bool MatchAndExplain(const HloInstruction* instruction,
::testing::MatchResultListener* listener) const override;
private:
int64 tuple_index_;
};
// HloInstruction* matchers for opcode and operands. Example:
// namespace op = xla::opcode_matchers;
// EXPECT_THAT(instruction,
@ -72,7 +102,6 @@ HLO_MATCHER(Exp);
HLO_MATCHER(Floor);
HLO_MATCHER(Fusion);
HLO_MATCHER(Ge);
HLO_MATCHER(GetTupleElement);
HLO_MATCHER(Gt);
HLO_MATCHER(Infeed);
HLO_MATCHER(IsFinite);
@ -90,7 +119,6 @@ HLO_MATCHER(Ne);
HLO_MATCHER(Negate);
HLO_MATCHER(Outfeed);
HLO_MATCHER(Pad);
HLO_MATCHER(Parameter);
HLO_MATCHER(Power);
HLO_MATCHER(Recv);
HLO_MATCHER(Reduce);
@ -115,6 +143,43 @@ HLO_MATCHER(Trace);
HLO_MATCHER(Transpose);
HLO_MATCHER(Tuple);
HLO_MATCHER(While);
// The special cases below let you check additional information about the
// HloInstruction, beyond just its opcode and operands. In all cases you can
// still use the generic matcher which doesn't check this info.
//
// Feel free to add additional custom matchers below.
// - Parameter(N) matches parameter number N.
// - Parameter() matches any parameter.
inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter(
int64 parameter_number) {
return ::testing::MakeMatcher(
new ::xla::testing::HloParameterMatcher(parameter_number));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
return ::testing::MakeMatcher(
new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
}
// GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
// tuple element of operand, while GetTupleElement(operand) matches any GTE
// operation on operand, and GetTupleElement() matches any GTE operation at all.
inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
::testing::Matcher<const HloInstruction*> operand, int64 tuple_index) {
return ::testing::MakeMatcher(
new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
::testing::Matcher<const HloInstruction*> operand) {
return ::testing::MakeMatcher(
new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand}));
}
inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement() {
return ::testing::MakeMatcher(
new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {}));
}
#undef HLO_MATCHER
} // namespace opcode_matchers