mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ez] add docblock for _iterate_exprs (#154377)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154377 Approved by: https://github.com/pianpwk ghstack dependencies: #154374, #154375, #154376, #154386, #154401, #154404, #154405
This commit is contained in:
parent
ab6cb85cb0
commit
abc3fdc7ac
|
|
@ -859,6 +859,24 @@ IterateExprs: TypeAlias = Union[IterateExprsAtom, Sequence[IterateExprsAtom]]
|
|||
|
||||
|
||||
def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]:
|
||||
"""
|
||||
Recursively iterate through a value and yield all sympy expressions contained within it.
|
||||
|
||||
This function traverses various data structures (tensors, lists, tuples, etc.) and extracts
|
||||
any symbolic expressions they contain. It's used for operations like finding free symbols
|
||||
in complex nested structures.
|
||||
|
||||
Args:
|
||||
val: The value to extract sympy expressions from. Can be a symbolic type (SymInt, SymFloat, SymBool),
|
||||
a sympy expression, a primitive type (int, float, bool), a container (tuple, list),
|
||||
a sparse tensor, a regular tensor, None, or a torch.Generator.
|
||||
|
||||
Yields:
|
||||
sympy.Basic: Each sympy expression found in the value.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the value is of an unsupported type.
|
||||
"""
|
||||
if isinstance(val, SymTypes):
|
||||
# This allow applies to the jagged layout NestedTensor case as
|
||||
# nested ints are not symbolic
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user