mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
CPU scalars are most commonly used for the philox random seed. Right now, any CPU input will skip cudagraphing the entire graph. We need both the traced graph and the runtime inputs to be cudaified.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125382
Approved by: https://github.com/jansel
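For illustration, a minimal sketch of the scenario the commit describes, not code from the PR: the function and names below (scale, alpha_cpu) are hypothetical, and the sketch assumes a CUDA device with torch.compile's "reduce-overhead" mode, which uses cudagraphs.

import torch

# Hypothetical sketch: with mode="reduce-overhead", torch.compile uses cudagraphs.
# Before this change, passing a CPU scalar tensor (e.g. a philox seed) as an input
# skipped cudagraphing the entire graph; the fix moves such inputs to CUDA
# ("cudaified") in both the traced graph and at runtime.
@torch.compile(mode="reduce-overhead")
def scale(x, alpha_cpu):
    # alpha_cpu is a 0-dim CPU tensor input
    return x * alpha_cpu

if torch.cuda.is_available():
    x = torch.randn(8, device="cuda")
    alpha = torch.tensor(2.0)  # CPU scalar tensor
    out = scale(x, alpha)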
60 lines
1.5 KiB
Python
from typing import Dict, Optional

import torch

from torch._logging import LazyString


def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
    """
    Returns a LazyString that formats the graph code.
    """

    def format_name():
        if maybe_id is not None:
            return f"{name} {maybe_id}"
        else:
            return name

    # With print_output=False, gm.print_readable returns the code as a string
    # instead of printing it.
    if "print_output" not in kwargs:
        kwargs["print_output"] = False

    return LazyString(
        lambda: _format_graph_code(
            f"===== {format_name()} =====\n",
            gm.forward.__code__.co_filename,
            gm.print_readable(**kwargs),
        )
    )


def _format_graph_code(name, filename, graph_str):
    """
    Returns a string that formats the graph code.
    """
    return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"


def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
    """
    Returns the nn_module_stack of the first call_function node.
    """
    for node in graph.nodes:
        if node.op == "call_function" and "nn_module_stack" in node.meta:
            return node.meta["nn_module_stack"]
    return None


def get_node_context(node, num_nodes=2) -> str:
    """
    Returns a string of the last num_nodes nodes in the graph.
    """
    node_contexts = []
    cur = node
    # Walk backwards from `node` through the graph's node list, stopping at the
    # root sentinel node.
    for i in range(num_nodes):
        node_contexts.append(cur.format_node())
        if cur.op == "root":
            break
        cur = cur.prev
    return "\n".join(node_contexts[::-1])
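A brief usage sketch of the helpers above, assuming they are in scope and applied to a toy module traced with torch.fx.symbolic_trace; the module and names (Toy, gm) are illustrative only, not part of the file.

import torch
import torch.nn as nn


class Toy(nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)


gm = torch.fx.symbolic_trace(Toy())

# The LazyString defers formatting until it is rendered, e.g. by a logging call.
print(lazy_format_graph_code("toy", gm, maybe_id=0))

# nn_module_stack of the first call_function node (None if the meta is absent,
# as it typically is with plain symbolic_trace).
print(first_call_function_nn_module_stack(gm.graph))

# The last node of the graph plus the node before it.
last = list(gm.graph.nodes)[-1]
print(get_node_context(last, num_nodes=2))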