from collections import deque from typing import Any from torch.fx import Graph, map_arg, Node from torch.utils._ordered_set import OrderedSet # flattens with support for slices # Note: a better way to do this would # be register/unregister slices as pytree nodes # but there is no unregister API in the pytorch # pytree impl def _get_flat_args( node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]] ) -> list[Node]: args = list[Any]() map_arg((node.args, node.kwargs), args.append) if node in node_to_additional_deps: args.extend(node_to_additional_deps[node]) return args def _get_flat_args_unique( node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]] ) -> OrderedSet[Node]: args = OrderedSet[Node]() map_arg((node.args, node.kwargs), args.add) if node in node_to_additional_deps: args.update(node_to_additional_deps[node]) return args def _detect_cycles(graph: Graph) -> str: current_path: deque[Node] = deque() current_path_set: set[Node] = set() pending: deque[tuple[Node, Node]] = deque() def add_to_current_path(node: Node) -> None: current_path.append(node) current_path_set.add(node) def pop_current_path() -> None: node = current_path.pop() current_path_set.remove(node) def current_path_head() -> Node: return current_path[-1] for origin in graph.find_nodes(op="placeholder"): current_path.clear() current_path_set.clear() add_to_current_path(origin) for child in origin.users: pending.append((child, origin)) while pending: cur_node, parent = pending.pop() while current_path_head() != parent: pop_current_path() if cur_node in current_path_set: current_path.append(cur_node) return f"cycle detected in path: {current_path}" add_to_current_path(cur_node) for child in cur_node.users: pending.append((child, cur_node)) return "no cycle detected"