[XLA] Add a member function to check if a tuple tree has any tuples

The function returns true if a tuple has only a root node.

PiperOrigin-RevId: 825062842
This commit is contained in:
Jian Cai 2025-10-28 08:58:58 -07:00 committed by TensorFlower Gardener
parent d1ca03b626
commit 7c6d13443d
2 changed files with 11 additions and 0 deletions

View File

@ -371,6 +371,8 @@ class TupleTree {
.ok();
}
bool IsTuple() const { return nodes_.size() > 1; }
absl::Status CopyCompatibleSubtreeFrom(const TupleTree<T>& other,
const ShapeIndex& src_index,
const ShapeIndex& dst_index) {

View File

@ -747,5 +747,14 @@ TEST_F(TupleTreeTest, ToNode) {
EXPECT_THAT(tree.ToNode({0, 0}), StatusIs(absl::StatusCode::kInvalidArgument,
"Cannot index into a leaf node"));
}
TEST_F(TupleTreeTest, IsTuple) {
TupleTree<int> tuple_tree({5});
TupleTree<int> non_tuple_tree(5);
EXPECT_TRUE(tuple_tree.IsTuple());
EXPECT_FALSE(non_tuple_tree.IsTuple());
}
} // namespace
} // namespace xla