[ez] add docblock for _iterate_exprs (#154377)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154377
Approved by: https://github.com/pianpwk
ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404, #154405
This commit is contained in:
bobrenjc93 2025-05-27 13:19:34 -07:00 committed by PyTorch MergeBot
parent ab6cb85cb0
commit abc3fdc7ac

View File

@ -859,6 +859,24 @@ IterateExprs: TypeAlias = Union[IterateExprsAtom, Sequence[IterateExprsAtom]]
def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]:
"""
Recursively iterate through a value and yield all sympy expressions contained within it.
This function traverses various data structures (tensors, lists, tuples, etc.) and extracts
any symbolic expressions they contain. It's used for operations like finding free symbols
in complex nested structures.
Args:
val: The value to extract sympy expressions from. Can be a symbolic type (SymInt, SymFloat, SymBool),
a sympy expression, a primitive type (int, float, bool), a container (tuple, list),
a sparse tensor, a regular tensor, None, or a torch.Generator.
Yields:
sympy.Basic: Each sympy expression found in the value.
Raises:
AssertionError: If the value is of an unsupported type.
"""
if isinstance(val, SymTypes):
# This allow applies to the jagged layout NestedTensor case as
# nested ints are not symbolic