pytorch/torch/fx/_utils.py
Simon Fan 00c6ca4459 [compiled autograd][cudagraphs] Inputs runtime wrapper to move cpu scalars to cuda (#125382)
CPU scalars are most commonly used for the philox random seed. Right now, any CPU input causes cudagraphs to skip the entire graph. We need both the traced graph and the runtime inputs to be cudaified.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125382
Approved by: https://github.com/jansel
2024-06-07 07:12:46 +00:00
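
As a hedged sketch of the idea only (the helper name below is made up, not the wrapper actually added by this PR), the runtime inputs wrapper amounts to copying 0-dim CPU tensors such as the philox seed/offset onto the CUDA device before the graph runs:

    import torch

    def cudaify_cpu_scalars(inputs, device="cuda"):
        # Move 0-dim CPU tensors (e.g. philox seed/offset) onto the CUDA
        # device; leave everything else untouched so the graph stays
        # eligible for cudagraph capture.
        return [
            t.to(device)
            if isinstance(t, torch.Tensor) and t.device.type == "cpu" and t.dim() == 0
            else t
            for t in inputs
        ]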

60 lines
1.5 KiB
Python

from typing import Dict, Optional

import torch
from torch._logging import LazyString


def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
    """
    Returns a LazyString that formats the graph code.
    """

    def format_name():
        if maybe_id is not None:
            return f"{name} {maybe_id}"
        else:
            return name

    # Default to returning the rendered source rather than printing it,
    # since the result is embedded in a log message.
    if "print_output" not in kwargs:
        kwargs["print_output"] = False

    # The lambda defers print_readable() until the LazyString is str()-ed.
    return LazyString(
        lambda: _format_graph_code(
            f"===== {format_name()} =====\n",
            gm.forward.__code__.co_filename,
            gm.print_readable(**kwargs),
        )
    )
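

# Usage sketch, written as a caller outside this module might use it (the
# logger and the tiny traced module are illustrative). %s-style logging
# defers str() on the LazyString, so gm.print_readable() only runs when
# DEBUG logging is actually enabled.
import logging

from torch.fx import symbolic_trace

log = logging.getLogger(__name__)


class _Example(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1


_gm = symbolic_trace(_Example())
log.debug("%s", lazy_format_graph_code("example graph", _gm, maybe_id=0))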


def _format_graph_code(name, filename, graph_str):
    """
    Returns a string that formats the graph code.
    """
    return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"


def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
    """
    Returns the nn_module_stack of the first call_function node.
    """
    for node in graph.nodes:
        if node.op == "call_function" and "nn_module_stack" in node.meta:
            return node.meta["nn_module_stack"]
    return None
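

# Usage sketch (the model is illustrative): nn_module_stack metadata is
# attached by Dynamo/export-style tracing, not by plain symbolic_trace,
# so this goes through torch.export.
class _Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x).relu()


_ep = torch.export.export(_Net(), (torch.randn(2, 4),))
# Prints the owning-module metadata of the first call_function node, or None.
print(first_call_function_nn_module_stack(_ep.graph_module.graph))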


def get_node_context(node, num_nodes=2) -> str:
    """
    Returns a string of the last num_nodes nodes in the graph.
    """
    node_contexts = []
    cur = node
    # Walk backwards from `node` via .prev, stopping at the sentinel root.
    for _ in range(num_nodes):
        node_contexts.append(cur.format_node())
        if cur.op == "root":
            break
        cur = cur.prev
    # Reverse so the context reads in graph order, ending at `node`.
    return "\n".join(node_contexts[::-1])