mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA] Add HLO matchers that check parameter numbers and GTE indices.
This lets you do EXPECT_THAT(foo, op::Parameter(42)); and EXPECT_THAT(bar, op::GetTupleElement(baz, 8)); PiperOrigin-RevId: 174113597
This commit is contained in:
parent
f97e7c69b8
commit
4aa90bfd39
|
|
@ -73,6 +73,35 @@ void HloMatcher::DescribeTo(::std::ostream* os) const {
|
|||
}
|
||||
}
|
||||
|
||||
bool HloParameterMatcher::MatchAndExplain(
|
||||
const HloInstruction* instruction,
|
||||
::testing::MatchResultListener* listener) const {
|
||||
if (!HloMatcher::MatchAndExplain(instruction, listener)) {
|
||||
return false;
|
||||
}
|
||||
if (instruction->parameter_number() != parameter_number_) {
|
||||
*listener << "has wrong parameter number (got "
|
||||
<< instruction->parameter_number() << ", want "
|
||||
<< parameter_number_ << ")";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloGetTupleElementMatcher::MatchAndExplain(
|
||||
const HloInstruction* instruction,
|
||||
::testing::MatchResultListener* listener) const {
|
||||
if (!HloMatcher::MatchAndExplain(instruction, listener)) {
|
||||
return false;
|
||||
}
|
||||
if (instruction->tuple_index() != tuple_index_) {
|
||||
*listener << "has wrong tuple index (got " << instruction->tuple_index()
|
||||
<< ", want " << tuple_index_ << ")";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
|
||||
void PrintTo(const HloInstruction* inst, ::std::ostream* os) {
|
||||
|
|
|
|||
|
|
@ -38,6 +38,36 @@ class HloMatcher : public ::testing::MatcherInterface<const HloInstruction*> {
|
|||
std::vector<::testing::Matcher<const HloInstruction*>> operands_;
|
||||
};
|
||||
|
||||
// Custom matcher for parameters, which accepts a parameter number.
|
||||
class HloParameterMatcher : public HloMatcher {
|
||||
public:
|
||||
explicit HloParameterMatcher(int64 parameter_number)
|
||||
: HloMatcher(HloOpcode::kParameter, /*operands=*/{}),
|
||||
parameter_number_(parameter_number) {}
|
||||
|
||||
bool MatchAndExplain(const HloInstruction* instruction,
|
||||
::testing::MatchResultListener* listener) const override;
|
||||
|
||||
private:
|
||||
int64 parameter_number_;
|
||||
};
|
||||
|
||||
// Custom matcher for get-tuple-element instructions, which accepts a tuple
|
||||
// index to match.
|
||||
class HloGetTupleElementMatcher : public HloMatcher {
|
||||
public:
|
||||
explicit HloGetTupleElementMatcher(
|
||||
::testing::Matcher<const HloInstruction*> operand, int64 tuple_index)
|
||||
: HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}),
|
||||
tuple_index_(tuple_index) {}
|
||||
|
||||
bool MatchAndExplain(const HloInstruction* instruction,
|
||||
::testing::MatchResultListener* listener) const override;
|
||||
|
||||
private:
|
||||
int64 tuple_index_;
|
||||
};
|
||||
|
||||
// HloInstruction* matchers for opcode and operands. Example:
|
||||
// namespace op = xla::opcode_matchers;
|
||||
// EXPECT_THAT(instruction,
|
||||
|
|
@ -72,7 +102,6 @@ HLO_MATCHER(Exp);
|
|||
HLO_MATCHER(Floor);
|
||||
HLO_MATCHER(Fusion);
|
||||
HLO_MATCHER(Ge);
|
||||
HLO_MATCHER(GetTupleElement);
|
||||
HLO_MATCHER(Gt);
|
||||
HLO_MATCHER(Infeed);
|
||||
HLO_MATCHER(IsFinite);
|
||||
|
|
@ -90,7 +119,6 @@ HLO_MATCHER(Ne);
|
|||
HLO_MATCHER(Negate);
|
||||
HLO_MATCHER(Outfeed);
|
||||
HLO_MATCHER(Pad);
|
||||
HLO_MATCHER(Parameter);
|
||||
HLO_MATCHER(Power);
|
||||
HLO_MATCHER(Recv);
|
||||
HLO_MATCHER(Reduce);
|
||||
|
|
@ -115,6 +143,43 @@ HLO_MATCHER(Trace);
|
|||
HLO_MATCHER(Transpose);
|
||||
HLO_MATCHER(Tuple);
|
||||
HLO_MATCHER(While);
|
||||
|
||||
// The special cases below let you check additional information about the
|
||||
// HloInstruction, beyond just its opcode and operands. In all cases you can
|
||||
// still use the generic matcher which doesn't check this info.
|
||||
//
|
||||
// Feel free to add additional custom matchers below.
|
||||
|
||||
// - Parameter(N) matches parameter number N.
|
||||
// - Parameter() matches any parameter.
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter(
|
||||
int64 parameter_number) {
|
||||
return ::testing::MakeMatcher(
|
||||
new ::xla::testing::HloParameterMatcher(parameter_number));
|
||||
}
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
|
||||
return ::testing::MakeMatcher(
|
||||
new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
|
||||
}
|
||||
|
||||
// GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
|
||||
// tuple element of operand, while GetTupleElement(operand) matches any GTE
|
||||
// operation on operand, and GetTupleElement() matches any GTE operation at all.
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
|
||||
::testing::Matcher<const HloInstruction*> operand, int64 tuple_index) {
|
||||
return ::testing::MakeMatcher(
|
||||
new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index));
|
||||
}
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
|
||||
::testing::Matcher<const HloInstruction*> operand) {
|
||||
return ::testing::MakeMatcher(
|
||||
new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand}));
|
||||
}
|
||||
inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement() {
|
||||
return ::testing::MakeMatcher(
|
||||
new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {}));
|
||||
}
|
||||
|
||||
#undef HLO_MATCHER
|
||||
} // namespace opcode_matchers
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user