pytorch/torch/fx/traceback.py
Shangdi Yu aa99e0958f Separate provenance tracking to different levels (#160383)
Summary: as title. We've received requests from various parties who are interested in turning on provenance tracking by default. In this PR, we prepare to turn on, by default, the part of provenance tracking that doesn't add much overhead.

- Change `provenance_tracking` config to `provenance_tracking_level`
- turn on the following provenance tracking by default when `basic_provenance_tracking`=True
    - `set_kernel_post_grad_provenance_tracing` for kernels; this adds a mapping between Triton kernels and post_grad nodes
    - `dump_inductor_provenance_info` if we're dumping tlparse log
    - `get_graph_provenance_json` and dump `create_mapping_pre_post_grad_nodes`. This creates a mapping between pre_grad and post_grad nodes. Since we're not turning on the provenance tracking in GraphTransformObserver by default, the mapping here may be incomplete/limited.
    - add stack trace from post grad nodes to inductor IR nodes
    - add exception swallowing for all functions above

Test Plan:
CI

Rollback Plan:

Differential Revision: D80031559

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160383
Approved by: https://github.com/angelayi
2025-08-15 04:59:35 +00:00

334 lines
11 KiB
Python

# mypy: allow-untyped-defs
import copy
import logging
import traceback
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional, Union
from ._compatibility import compatibility
from .graph import Graph
from .node import Node
# Module-level logger; get_graph_provenance_json uses it to report unexpected errors.
log = logging.getLogger(__name__)

# Public names of this module.
__all__ = [
    "preserve_node_meta",
    "has_preserved_node_meta",
    "set_stack_trace",
    "set_grad_fn_seq_nr",
    "reset_grad_fn_seq_nr",
    "format_stack",
    "set_current_meta",
    "get_current_meta",
    "NodeSource",
    "NodeSourceAction",
    "get_graph_provenance_json",
]

# Process-wide metadata dict. It is written by set_stack_trace and the
# grad_fn_seq_nr helpers, swapped out by set_current_meta, restored by
# preserve_node_meta on exit, and exposed through get_current_meta().
current_meta: dict[str, Any] = {}

# Global switch: the helpers in this module only record metadata while this is
# True. preserve_node_meta() turns it on for the duration of its `with` block.
should_preserve_node_meta = False
@compatibility(is_backward_compatible=False)
class NodeSourceAction(Enum):
    """Action labels that can be attached to a ``NodeSource`` entry.

    ``NodeSource`` serializes a list of these as the lowercase member names
    joined with ``"+"`` (for example ``"create+replace"``).
    """

    # Recorded by set_current_meta for the node whose meta becomes current.
    CREATE = "create"
    # Not used by the code visible in this file; presumably recorded by passes
    # that replace an existing node -- confirm against callers.
    REPLACE = "replace"
@compatibility(is_backward_compatible=False)
class NodeSource:
"""
NodeSource is a data structure that contains the provenance information of a node.
If node `a` is created from node `b`, then `a.meta["from_node"]` may contain NodeSource(b).
"""
class NodeInfo:
def __init__(self, name: str, target: str, graph_id: int):
self.name = name
self.target = target
self.graph_id = graph_id
pass_name: str
action: list["NodeSourceAction"]
from_node: list["NodeSource"]
node_info: Optional["NodeInfo"]
_dict: Optional[dict[str, Any]]
_action_string: Optional[str]
def __init__(
self,
node: Optional[Node],
pass_name: str = "",
action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None,
):
self.pass_name = pass_name
if action is None:
action = []
elif not isinstance(action, list):
action = [action]
for a in action:
assert isinstance(a, NodeSourceAction)
self.action = action
if node:
self.node_info = self.NodeInfo(
name=node.name, target=str(node.target), graph_id=id(node.graph)
)
self.from_node = (
copy.deepcopy(node.meta["from_node"])
if "from_node" in node.meta
else []
)
else:
self.node_info = None
self.from_node = []
# cache the action string and dict representation for performance.
self._action_string: Optional[str] = None
self._dict: Optional[dict[str, Any]] = None
@property
def name(self) -> str:
return self.node_info.name if self.node_info else ""
@property
def target(self) -> str:
return self.node_info.target if self.node_info else ""
@property
def graph_id(self) -> int:
return self.node_info.graph_id if self.node_info else -1
def __repr__(self):
return self.print_readable()
def _get_action_string(self):
if self._action_string is None:
self._action_string = "+".join([a.name.lower() for a in self.action])
return self._action_string
def print_readable(self, indent=0):
if indent > 9:
return ""
result = ""
action_string = self._get_action_string()
result += (
" " * indent * 4
+ f"(name={self.name}, pass_name={self.pass_name}, action={action_string}, graph_id={self.graph_id})\n"
)
for item in self.from_node:
result += item.print_readable(indent + 1)
return result
def to_dict(self) -> dict:
if self._dict is None:
# Convert the object to a dictionary
action_string = self._get_action_string()
self._dict = {
"name": self.name,
"target": self.target,
"graph_id": self.graph_id,
"pass_name": self.pass_name,
"action": action_string,
"from_node": [node.to_dict() for node in self.from_node],
}
assert self._dict is not None
return self._dict
def __eq__(self, other: object):
if not isinstance(other, NodeSource):
return False
return self.to_dict() == other.to_dict()
def __hash__(self):
# Create a hash based on the dictionary representation
# We need to convert the dict to a hashable form
def _make_hashable(obj):
if isinstance(obj, dict):
return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
elif isinstance(obj, list):
return tuple(_make_hashable(item) for item in obj)
else:
return obj
return hash(_make_hashable(self.to_dict()))
@classmethod
def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]:
"""
Recursively deserialize from_node metadata from dictionary data.
It is used to deserialize the from_node field from serialized metadata.
Please use constructor NodeSource(node, ...) to create a NodeSource object.
"""
if d is None:
return None
assert isinstance(d, dict), f"Expected a dict, got {type(d)}"
# Create a NodeSource object directly without going through the constructor
# to avoid issues with graph ID and node creation
node_source = NodeSource.__new__(NodeSource)
# Reset the cached properties
node_source._action_string = None
node_source._dict = None
# Set the basic attributes
node_source.pass_name = d.get("pass_name", "")
# Parse action string back to NodeSourceAction enum list
action_str = d.get("action", "")
actions = []
if action_str:
for action_name in action_str.split("+"):
if action_name.upper() == "CREATE":
actions.append(NodeSourceAction.CREATE)
elif action_name.upper() == "REPLACE":
actions.append(NodeSourceAction.REPLACE)
node_source.action = actions
# Create the NodeInfo object directly
if "name" in d and "target" in d and "graph_id" in d:
node_info = NodeSource.NodeInfo(
d.get("name", ""), d.get("target", ""), d.get("graph_id", -1)
)
node_source.node_info = node_info
else:
node_source.node_info = None
# Recursively deserialize nested from_node
if d.get("from_node", None) is not None:
node_source.from_node = [
result
for fn in d.get("from_node", [])
if (result := cls._from_dict(fn)) is not None
]
else:
node_source.from_node = []
return node_source
@compatibility(is_backward_compatible=False)
@contextmanager
def preserve_node_meta(enable=True):
    """
    Context manager that switches node-meta preservation on for its body.

    While active, ``should_preserve_node_meta`` is True. On exit (normal or via
    an exception) the flag is restored to its prior value and ``current_meta``
    is reset to the snapshot taken on entry.

    Args:
        enable: When False the context manager does nothing at all.
    """
    global should_preserve_node_meta
    global current_meta

    if not enable:
        # Disabled: plain pass-through, no state is read or written.
        yield
        return

    # A shallow snapshot is enough since fields of current_meta are not mutated.
    prior_flag, prior_meta = should_preserve_node_meta, current_meta.copy()
    try:
        should_preserve_node_meta = True
        yield
    finally:
        should_preserve_node_meta, current_meta = prior_flag, prior_meta
@compatibility(is_backward_compatible=False)
def set_stack_trace(stack: list[str]):
    """
    Store ``stack``, concatenated into one string, under ``current_meta["stack_trace"]``.

    Does nothing unless node-meta preservation is on and ``stack`` is non-empty.
    """
    global current_meta

    if not (should_preserve_node_meta and stack):
        return
    current_meta["stack_trace"] = "".join(stack)
@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
    """
    Push ``seq_nr`` onto ``current_meta["grad_fn_seq_nr"]`` and bump the
    ``"in_grad_fn"`` nesting counter.

    Does nothing unless node-meta preservation is on.
    """
    global current_meta

    if not should_preserve_node_meta:
        return

    # The seq_nr is captured by eager mode in the grad_fn during forward
    recorded = current_meta.get("grad_fn_seq_nr", [])
    # Build a new list rather than appending in place.
    current_meta["grad_fn_seq_nr"] = recorded + [seq_nr]
    depth = current_meta.get("in_grad_fn", 0)
    current_meta["in_grad_fn"] = depth + 1
@compatibility(is_backward_compatible=False)
def reset_grad_fn_seq_nr():
    """
    Undo one ``set_grad_fn_seq_nr`` call.

    At the outermost level both ``"in_grad_fn"`` and ``"grad_fn_seq_nr"`` are
    removed from ``current_meta``; otherwise the counter is decremented and
    the last recorded seq_nr is dropped. Does nothing unless node-meta
    preservation is on.

    Raises:
        AssertionError: If no ``set_grad_fn_seq_nr`` call is outstanding.
    """
    # NB: reset state properly, this would be helpful towards supporting
    # reentrant autograd if we actually wanted to do that.
    global current_meta

    if not should_preserve_node_meta:
        return

    depth = current_meta.get("in_grad_fn", 0)
    assert depth > 0

    if depth != 1:
        # Still nested: pop one level.
        current_meta["in_grad_fn"] = depth - 1
        current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]
        return

    # Leaving the outermost level: clear both keys entirely.
    del current_meta["in_grad_fn"]
    del current_meta["grad_fn_seq_nr"]
@compatibility(is_backward_compatible=False)
def format_stack() -> list[str]:
    """
    Return the stack trace to associate with the current node.

    With node-meta preservation on, this is a one-element list holding
    ``current_meta["stack_trace"]`` ("" if unset). Otherwise it is the
    formatted live Python stack.
    """
    if not should_preserve_node_meta:
        # fallback to traceback.format_stack(); the slice drops this
        # function's own frame from the result.
        frames = traceback.extract_stack()[:-1]
        return traceback.format_list(frames)
    return [current_meta.get("stack_trace", "")]
@compatibility(is_backward_compatible=False)
def has_preserved_node_meta() -> bool:
    """Return the current value of the module-level ``should_preserve_node_meta`` flag."""
    return should_preserve_node_meta
@compatibility(is_backward_compatible=False)
@contextmanager
def set_current_meta(node, pass_name=""):
    """
    Context manager that makes a copy of ``node.meta`` the current metadata.

    Inside the block, ``current_meta`` is a shallow copy of ``node.meta`` whose
    ``"from_node"`` entry is a single ``NodeSource`` describing ``node``. The
    previous ``current_meta`` object is put back on exit. When node-meta
    preservation is off, or ``node.meta`` is empty, nothing is changed.

    Args:
        node: Object exposing a ``meta`` mapping (and the attributes
            ``NodeSource`` reads from it).
        pass_name: Pass name recorded on the created ``NodeSource``.
    """
    global current_meta

    if not (should_preserve_node_meta and node.meta):
        yield
        return

    outer_meta = current_meta
    try:
        current_meta = node.meta.copy()

        # Update the "from_node" field in current_meta for provenance tracking.
        # Instead of appending, overwrite the "from_node" field because current_meta
        # will be assigned to the new node. The new NodeSource(node, ...) will
        # include the information from the previous current_meta["from_node"].
        provenance = NodeSource(node, pass_name, NodeSourceAction.CREATE)
        current_meta["from_node"] = [provenance]
        yield
    finally:
        current_meta = outer_meta
@compatibility(is_backward_compatible=False)
def get_current_meta() -> dict[str, Any]:
    """Return the module-level ``current_meta`` dict itself (not a copy)."""
    return current_meta
@compatibility(is_backward_compatible=False)
def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
    """
    Given an fx.Graph, return a json that contains the provenance information of each node.

    Only ``call_function`` nodes are included. Each is keyed by node name and
    maps to the ``to_dict()`` form of every entry in its ``meta["from_node"]``
    (an empty list when that key is absent).

    Returns:
        The mapping described above, or ``{}`` if any exception occurs while
        building it (the error is logged instead of raised).
    """
    try:
        return {
            node.name: [
                source.to_dict() for source in node.meta.get("from_node", [])
            ]
            for node in graph.nodes
            if node.op == "call_function"
        }
    except Exception as e:
        # Since this is just debugging, it should never interfere with regular
        # program execution, so we use this try-except to guard against any error
        # TODO: log the error to scuba table for better signal
        log.error("Unexpected error in get_graph_provenance_json: %s", e)
        log.error(traceback.format_exc())
        return {}