If we detect that a compiled model is using CUDA in a meaningful way, we should store information about CUDA and the hardware.

Example: `SystemInfo(python_version='3.12.9', torch_version='2.9.0a0+gite02b0e6', cuda_version='12.6', triton_version=(3, 4), gpu_name='NVIDIA PG509-210')`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162438
Approved by: https://github.com/zhxchen17
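A minimal sketch of how such a record could be gathered, assuming a simple `SystemInfo` dataclass with a subset of the fields from the example above (`collect_system_info` and the dataclass are hypothetical illustrations; the actual recording logic lives outside this file):

    import dataclasses
    import platform
    from typing import Optional

    import torch


    @dataclasses.dataclass
    class SystemInfo:  # hypothetical stand-in using field names from the example above
        python_version: str
        torch_version: str
        cuda_version: Optional[str] = None
        gpu_name: Optional[str] = None


    def collect_system_info(uses_gpu: bool) -> SystemInfo:
        # Only record CUDA/GPU details when the compiled graph meaningfully uses the GPU.
        has_cuda = uses_gpu and torch.cuda.is_available()
        return SystemInfo(
            python_version=platform.python_version(),
            torch_version=torch.__version__,
            cuda_version=torch.version.cuda if has_cuda else None,
            gpu_name=torch.cuda.get_device_name() if has_cuda else None,
        )

The module below supplies the detection side: `_graph_uses_non_cpu` decides whether a traced graph touches a non-CPU device at all.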
from collections import deque
from typing import Any, Optional

import torch
from torch.fx import Graph, map_arg, Node
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_flatten


# flattens with support for slices
# Note: a better way to do this would be
# to register/unregister slices as pytree nodes,
# but there is no unregister API in the pytorch
# pytree impl
def _get_flat_args(
    node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> list[Node]:
    args = list[Any]()
    map_arg((node.args, node.kwargs), args.append)
    if node in node_to_additional_deps:
        args.extend(node_to_additional_deps[node])
    return args

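# Illustrative example (not part of the original module), assuming a function
# traced as `def f(x, n): return x[0:n]`: the resulting getitem node's args
# contain slice(0, n_node, None), where n_node is itself a Node. map_arg
# recurses into slices, so _get_flat_args(getitem_node, {}) sees both x and
# n_node, whereas torch.utils._pytree.tree_flatten treats the slice as a leaf
# and would miss that dependency -- which is what the comment above refers to.
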

def _get_flat_args_unique(
    node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> OrderedSet[Node]:
    args = OrderedSet[Node]()
    map_arg((node.args, node.kwargs), args.add)
    if node in node_to_additional_deps:
        args.update(node_to_additional_deps[node])
    return args


def _detect_cycles(
    graph: Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> str:
    """Iterative DFS from each output node over args/kwargs plus the extra
    edges in node_to_additional_deps; returns a message describing the first
    cycle found, or "no cycle detected"."""
    current_path: deque[Node] = deque()
    current_path_set: set[Node] = set()
    pending: deque[tuple[Node, Node]] = deque()

    def add_to_current_path(node: Node) -> None:
        current_path.append(node)
        current_path_set.add(node)

    def pop_current_path() -> None:
        node = current_path.pop()
        current_path_set.remove(node)

    def current_path_head() -> Node:
        return current_path[-1]

    for origin in graph.find_nodes(op="output"):
        current_path.clear()
        current_path_set.clear()
        add_to_current_path(origin)
        for child in _get_flat_args_unique(origin, node_to_additional_deps):
            pending.append((child, origin))

        while pending:
            cur_node, parent = pending.pop()

            # handle backtracking
            while current_path and current_path_head() != parent:
                pop_current_path()

            if not isinstance(cur_node, Node):
                continue

            if cur_node in current_path_set:
                current_path.append(cur_node)
                return f"cycle detected in path: {current_path}"

            add_to_current_path(cur_node)

            for child in _get_flat_args_unique(cur_node, node_to_additional_deps):
                pending.append((child, cur_node))

    return "no cycle detected"

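# Illustrative usage (not part of the original module): node_to_additional_deps
# injects edges that are invisible in args/kwargs, and such an edge can close a
# cycle. Assuming `f` below traces to placeholder -> relu -> add -> output:
#
#     from torch.fx import symbolic_trace
#     from torch.utils._ordered_set import OrderedSet
#
#     def f(x):
#         return x.relu() + 1
#
#     gm = symbolic_trace(f)
#     x_node, relu_node, add_node, _ = gm.graph.nodes
#     print(_detect_cycles(gm.graph, {}))  # "no cycle detected"
#     # Make the placeholder additionally depend on the add node -> cycle.
#     print(_detect_cycles(gm.graph, {x_node: OrderedSet([add_node])}))
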

def _graph_uses_non_cpu(graph: Optional[Graph]) -> bool:
    """Return True if any node in the graph appears to touch a non-CPU device,
    via tensor/device metadata, explicit .cuda()/.to("cuda") calls, or non-CPU
    device objects in args/kwargs."""
    if graph is None:
        return False

    def _is_non_cpu(x: Any) -> bool:
        if isinstance(x, torch.device):
            return x.type != "cpu"
        if isinstance(x, torch.Tensor):
            return x.device.type != "cpu"
        return False

    def _flatten_meta(node: Node, key: str) -> list[Any]:
        if key not in node.meta:
            return []
        flat, _ = tree_flatten(node.meta[key])
        return flat

    for node in graph.nodes:
        # Check tensor/device metadata recorded on the node
        for key in ("val", "example_value"):
            for obj in _flatten_meta(node, key):
                if _is_non_cpu(obj):
                    return True

        # Check for device conversions
        if node.op == "call_method":
            if node.target == "cuda":
                return True
            if node.target == "to" and "cuda" in node.args:
                return True

        # Check args/kwargs for non-CPU device specs
        flat_args, _ = tree_flatten((node.args, node.kwargs))
        for obj in flat_args:
            if _is_non_cpu(obj):
                return True

    return False

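# Illustrative usage (not part of the original module): no GPU is needed for
# the check itself, because an explicit `.cuda()` shows up as a call_method
# node in the traced graph.
#
#     from torch.fx import symbolic_trace
#
#     def g(x):
#         return x.cuda() + 1
#
#     gm = symbolic_trace(g)
#     assert _graph_uses_non_cpu(gm.graph)   # found via the "cuda" call_method
#     assert not _graph_uses_non_cpu(None)   # a missing graph counts as CPU-only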