mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This PR moves `insert_deferred_runtime_asserts` from dynamo to `torch.fx.passes` and uses it to add runtime assertions for non-strict export. Differential Revision: D55944267. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123681. Approved by: https://github.com/tugsbayasgalan, https://github.com/angelayi
57 lines
1.4 KiB
Python
57 lines
1.4 KiB
Python
from typing import Dict, Optional
|
|
|
|
import torch
|
|
|
|
from torch._logging import LazyString
|
|
|
|
|
|
def lazy_format_graph_code(name, gm, maybe_id=None):
    """
    Wrap the readable code of ``gm`` in a ``LazyString`` for deferred logging.

    Nothing is formatted here; the returned ``LazyString`` holds a callback that,
    when invoked, builds a header line ``===== <name>[ <maybe_id>] =====`` and
    combines it with the file name of ``gm.forward``'s code object and the text
    of ``gm.print_readable(print_output=False)``.

    Args:
        name: label shown in the header line.
        gm: the module whose graph code is rendered; must expose ``forward``
            and ``print_readable``.
        maybe_id: optional suffix appended to ``name`` (separated by a space)
            when it is not ``None``.

    Returns:
        A ``LazyString`` that renders the graph code on demand.
    """

    def _render():
        # Everything, including the header label, is computed inside the
        # callback so no formatting work happens unless the string is used.
        label = name if maybe_id is None else f"{name} {maybe_id}"
        header = f"===== {label} =====\n"
        return _format_graph_code(
            header,
            gm.forward.__code__.co_filename,
            gm.print_readable(print_output=False),
        )

    return LazyString(_render)
|
|
|
|
|
|
def _format_graph_code(name, filename, graph_str):
    """
    Assemble the "TRACED GRAPH" log text.

    Args:
        name: header text placed first after the banner.
        filename: source file name shown after the header.
        graph_str: the rendered graph code.

    Returns:
        ``"TRACED GRAPH\\n "`` followed by the three arguments separated by
        single spaces, terminated by a newline.
    """
    template = "TRACED GRAPH\n {} {} {}\n"
    return template.format(name, filename, graph_str)
|
|
|
|
|
|
def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
    """
    Find the ``nn_module_stack`` metadata of the earliest qualifying node.

    Nodes are scanned in graph order; the first node whose ``op`` is
    ``"call_function"`` and whose ``meta`` contains an ``"nn_module_stack"``
    entry supplies the result.

    Args:
        graph: the FX graph to scan.

    Returns:
        That node's ``meta["nn_module_stack"]`` value, or ``None`` when no
        ``call_function`` node carries the entry.
    """
    stacks = (
        candidate.meta["nn_module_stack"]
        for candidate in graph.nodes
        if candidate.op == "call_function" and "nn_module_stack" in candidate.meta
    )
    return next(stacks, None)
|
|
|
|
|
|
def get_node_context(node, num_nodes=2) -> str:
    """
    Return the formatted text of ``node`` and its immediate predecessors.

    Starting at ``node``, walks backwards through the graph via ``Node.prev``
    and collects ``format_node()`` output for at most ``num_nodes`` nodes
    (``node`` itself counts as one). The walk stops early at the graph's
    ``"root"`` sentinel, which is not a real node and is not included.

    Args:
        node: the ``torch.fx.Node`` to report on.
        num_nodes: maximum number of nodes to include; ``0`` or less yields
            an empty string.

    Returns:
        The formatted nodes in graph order (oldest first), joined by newlines.
    """
    node_contexts = []
    cur = node
    for _ in range(num_nodes):
        # The graph's node list is anchored by a sentinel whose op is "root".
        # Reaching it means there are no more real predecessors, so stop
        # before formatting it (its text is meaningless in error context).
        if cur.op == "root":
            break
        node_contexts.append(cur.format_node())
        cur = cur.prev
    # Collected newest-first; emit in graph order.
    return "\n".join(reversed(node_contexts))
|