mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[XLA] Add HLO matchers that check parameter numbers and GTE indices.
This lets you do EXPECT_THAT(foo, op::Parameter(42)); and EXPECT_THAT(bar, op::GetTupleElement(baz, 8)); PiperOrigin-RevId: 174113597
This commit is contained in:
parent
f97e7c69b8
commit
4aa90bfd39
|
|
@ -73,6 +73,35 @@ void HloMatcher::DescribeTo(::std::ostream* os) const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HloParameterMatcher::MatchAndExplain(
|
||||||
|
const HloInstruction* instruction,
|
||||||
|
::testing::MatchResultListener* listener) const {
|
||||||
|
if (!HloMatcher::MatchAndExplain(instruction, listener)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (instruction->parameter_number() != parameter_number_) {
|
||||||
|
*listener << "has wrong parameter number (got "
|
||||||
|
<< instruction->parameter_number() << ", want "
|
||||||
|
<< parameter_number_ << ")";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HloGetTupleElementMatcher::MatchAndExplain(
|
||||||
|
const HloInstruction* instruction,
|
||||||
|
::testing::MatchResultListener* listener) const {
|
||||||
|
if (!HloMatcher::MatchAndExplain(instruction, listener)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (instruction->tuple_index() != tuple_index_) {
|
||||||
|
*listener << "has wrong tuple index (got " << instruction->tuple_index()
|
||||||
|
<< ", want " << tuple_index_ << ")";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace testing
|
} // namespace testing
|
||||||
|
|
||||||
void PrintTo(const HloInstruction* inst, ::std::ostream* os) {
|
void PrintTo(const HloInstruction* inst, ::std::ostream* os) {
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,36 @@ class HloMatcher : public ::testing::MatcherInterface<const HloInstruction*> {
|
||||||
std::vector<::testing::Matcher<const HloInstruction*>> operands_;
|
std::vector<::testing::Matcher<const HloInstruction*>> operands_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Custom matcher for parameters, which accepts a parameter number.
|
||||||
|
class HloParameterMatcher : public HloMatcher {
|
||||||
|
public:
|
||||||
|
explicit HloParameterMatcher(int64 parameter_number)
|
||||||
|
: HloMatcher(HloOpcode::kParameter, /*operands=*/{}),
|
||||||
|
parameter_number_(parameter_number) {}
|
||||||
|
|
||||||
|
bool MatchAndExplain(const HloInstruction* instruction,
|
||||||
|
::testing::MatchResultListener* listener) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64 parameter_number_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Custom matcher for get-tuple-element instructions, which accepts a tuple
|
||||||
|
// index to match.
|
||||||
|
class HloGetTupleElementMatcher : public HloMatcher {
|
||||||
|
public:
|
||||||
|
explicit HloGetTupleElementMatcher(
|
||||||
|
::testing::Matcher<const HloInstruction*> operand, int64 tuple_index)
|
||||||
|
: HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}),
|
||||||
|
tuple_index_(tuple_index) {}
|
||||||
|
|
||||||
|
bool MatchAndExplain(const HloInstruction* instruction,
|
||||||
|
::testing::MatchResultListener* listener) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64 tuple_index_;
|
||||||
|
};
|
||||||
|
|
||||||
// HloInstruction* matchers for opcode and operands. Example:
|
// HloInstruction* matchers for opcode and operands. Example:
|
||||||
// namespace op = xla::opcode_matchers;
|
// namespace op = xla::opcode_matchers;
|
||||||
// EXPECT_THAT(instruction,
|
// EXPECT_THAT(instruction,
|
||||||
|
|
@ -72,7 +102,6 @@ HLO_MATCHER(Exp);
|
||||||
HLO_MATCHER(Floor);
|
HLO_MATCHER(Floor);
|
||||||
HLO_MATCHER(Fusion);
|
HLO_MATCHER(Fusion);
|
||||||
HLO_MATCHER(Ge);
|
HLO_MATCHER(Ge);
|
||||||
HLO_MATCHER(GetTupleElement);
|
|
||||||
HLO_MATCHER(Gt);
|
HLO_MATCHER(Gt);
|
||||||
HLO_MATCHER(Infeed);
|
HLO_MATCHER(Infeed);
|
||||||
HLO_MATCHER(IsFinite);
|
HLO_MATCHER(IsFinite);
|
||||||
|
|
@ -90,7 +119,6 @@ HLO_MATCHER(Ne);
|
||||||
HLO_MATCHER(Negate);
|
HLO_MATCHER(Negate);
|
||||||
HLO_MATCHER(Outfeed);
|
HLO_MATCHER(Outfeed);
|
||||||
HLO_MATCHER(Pad);
|
HLO_MATCHER(Pad);
|
||||||
HLO_MATCHER(Parameter);
|
|
||||||
HLO_MATCHER(Power);
|
HLO_MATCHER(Power);
|
||||||
HLO_MATCHER(Recv);
|
HLO_MATCHER(Recv);
|
||||||
HLO_MATCHER(Reduce);
|
HLO_MATCHER(Reduce);
|
||||||
|
|
@ -115,6 +143,43 @@ HLO_MATCHER(Trace);
|
||||||
HLO_MATCHER(Transpose);
|
HLO_MATCHER(Transpose);
|
||||||
HLO_MATCHER(Tuple);
|
HLO_MATCHER(Tuple);
|
||||||
HLO_MATCHER(While);
|
HLO_MATCHER(While);
|
||||||
|
|
||||||
|
// The special cases below let you check additional information about the
|
||||||
|
// HloInstruction, beyond just its opcode and operands. In all cases you can
|
||||||
|
// still use the generic matcher which doesn't check this info.
|
||||||
|
//
|
||||||
|
// Feel free to add additional custom matchers below.
|
||||||
|
|
||||||
|
// - Parameter(N) matches parameter number N.
|
||||||
|
// - Parameter() matches any parameter.
|
||||||
|
inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter(
|
||||||
|
int64 parameter_number) {
|
||||||
|
return ::testing::MakeMatcher(
|
||||||
|
new ::xla::testing::HloParameterMatcher(parameter_number));
|
||||||
|
}
|
||||||
|
inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
|
||||||
|
return ::testing::MakeMatcher(
|
||||||
|
new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
|
||||||
|
// tuple element of operand, while GetTupleElement(operand) matches any GTE
|
||||||
|
// operation on operand, and GetTupleElement() matches any GTE operation at all.
|
||||||
|
inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
|
||||||
|
::testing::Matcher<const HloInstruction*> operand, int64 tuple_index) {
|
||||||
|
return ::testing::MakeMatcher(
|
||||||
|
new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index));
|
||||||
|
}
|
||||||
|
inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
|
||||||
|
::testing::Matcher<const HloInstruction*> operand) {
|
||||||
|
return ::testing::MakeMatcher(
|
||||||
|
new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand}));
|
||||||
|
}
|
||||||
|
inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement() {
|
||||||
|
return ::testing::MakeMatcher(
|
||||||
|
new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {}));
|
||||||
|
}
|
||||||
|
|
||||||
#undef HLO_MATCHER
|
#undef HLO_MATCHER
|
||||||
} // namespace opcode_matchers
|
} // namespace opcode_matchers
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user