mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Builds on top of https://github.com/pytorch/pytorch/pull/163673 and https://github.com/pytorch/pytorch/pull/164174. This will be used in the follow-up PRs to apply regional inductor compilation. The existing implementation let Dynamo trace into `torch.fx.traceback.annotate`, but that's not what we want. We want Dynamo to essentially run the `torch.fx.traceback.annotate` function in eager mode, so that every FX node created in the Dynamo FX graph carries the custom metadata. What does not work? * We still have to set the context manager `torch.fx.traceback.preserve_node_meta()` in the user code because CI was unhappy. This can be fixed, but it will take some perseverance. * This does not work with graph breaks yet. We can solve that problem, if needed, in a separate PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164678 Approved by: https://github.com/SherlockNoMad, https://github.com/jansel, https://github.com/xmfan
391 lines
12 KiB
Python
391 lines
12 KiB
Python
# mypy: allow-untyped-defs
|
|
import copy
|
|
import logging
|
|
import traceback
|
|
from contextlib import contextmanager
|
|
from enum import Enum
|
|
from typing import Any, Optional, Union
|
|
|
|
from torch._utils_internal import signpost_event
|
|
|
|
from ._compatibility import compatibility
|
|
from .graph import Graph
|
|
from .node import Node
|
|
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
|
|
|
|
# Public API of this module (order preserved).
__all__ = [
    # Context managers and accessors for the module-level node metadata.
    "annotate", "preserve_node_meta", "has_preserved_node_meta",
    "set_stack_trace", "set_grad_fn_seq_nr", "reset_grad_fn_seq_nr",
    "format_stack", "set_current_meta", "get_current_meta",
    # Provenance tracking.
    "NodeSource", "NodeSourceAction", "get_graph_provenance_json",
]
|
|
|
|
# Flag consulted by the helpers below; preserve_node_meta() toggles it.
should_preserve_node_meta = False

# Module-level metadata dict that the helpers below read, update, and rebind;
# get_current_meta() returns this object itself (not a copy).
current_meta: dict[str, Any] = {}
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
class NodeSourceAction(Enum):
    """
    Action recorded on a NodeSource.

    Each member's value is the lowercase form of its name; NodeSource
    serializes a list of actions as those lowercase names joined with "+".
    """

    CREATE = "create"
    REPLACE = "replace"
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
class NodeSource:
    """
    NodeSource is a data structure that contains the provenance information of a node.

    If node `a` is created from node `b`, then `a.meta["from_node"]` may contain NodeSource(b).

    A NodeSource stores the name of the pass that produced the derived node, the
    list of NodeSourceAction values applied, a small snapshot of the source node
    (its name, ``str(target)`` and the ``id()`` of its graph), and a deep copy of
    the source node's own ``meta["from_node"]`` list, so provenance forms a tree.
    """

    class NodeInfo:
        """Snapshot of the identifying fields of a source node."""

        def __init__(self, name: str, target: str, graph_id: int):
            self.name = name
            self.target = target
            self.graph_id = graph_id

    pass_name: str
    action: list["NodeSourceAction"]
    from_node: list["NodeSource"]
    node_info: Optional["NodeInfo"]
    _dict: Optional[dict[str, Any]]
    _action_string: Optional[str]

    def __init__(
        self,
        node: Optional[Node],
        pass_name: str = "",
        action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None,
    ):
        """
        Args:
            node: The source node, or None to create a record with no source
                node (name/target are then "" and graph_id is -1).
            pass_name: Name of the pass that produced the derived node.
            action: A single NodeSourceAction, a list of them, or None
                (meaning no actions).
        """
        self.pass_name = pass_name

        # Normalize `action` to a list of NodeSourceAction.
        if action is None:
            action = []
        elif not isinstance(action, list):
            action = [action]
        for a in action:
            assert isinstance(a, NodeSourceAction)
        self.action = action

        # Compare with None explicitly instead of relying on Node truthiness.
        if node is not None:
            self.node_info = self.NodeInfo(
                name=node.name, target=str(node.target), graph_id=id(node.graph)
            )
            # Deep-copy so this record does not alias the source node's
            # provenance list or the NodeSource objects inside it.
            self.from_node = (
                copy.deepcopy(node.meta["from_node"])
                if "from_node" in node.meta
                else []
            )
        else:
            self.node_info = None
            self.from_node = []

        # cache the action string and dict representation for performance.
        self._action_string = None
        self._dict = None

    @property
    def name(self) -> str:
        """Name of the source node, or "" when there is no source node."""
        return self.node_info.name if self.node_info else ""

    @property
    def target(self) -> str:
        """Stringified target of the source node, or "" when there is none."""
        return self.node_info.target if self.node_info else ""

    @property
    def graph_id(self) -> int:
        """``id()`` of the source node's graph, or -1 when there is none."""
        return self.node_info.graph_id if self.node_info else -1

    def __repr__(self):
        return self.print_readable()

    def _get_action_string(self):
        """Return the actions as lowercase names joined by "+" (cached)."""
        if self._action_string is None:
            self._action_string = "+".join([a.name.lower() for a in self.action])
        return self._action_string

    def print_readable(self, indent=0):
        """
        Return an indented, one-line-per-record description of this record
        followed by its ``from_node`` ancestors (4 spaces per level).

        Records nested deeper than indent level 9 are omitted.
        """
        if indent > 9:
            return ""
        result = ""
        action_string = self._get_action_string()
        result += (
            " " * indent * 4
            + f"(name={self.name}, pass_name={self.pass_name}, action={action_string}, graph_id={self.graph_id})\n"
        )
        for item in self.from_node:
            result += item.print_readable(indent + 1)
        return result

    def to_dict(self) -> dict:
        """
        Return a plain-dict representation (nested ``from_node`` entries are
        converted recursively).

        The result is cached on first call; later changes to ``action`` or
        ``from_node`` are not reflected in the cached dict.
        """
        if self._dict is None:
            # Convert the object to a dictionary
            action_string = self._get_action_string()
            self._dict = {
                "name": self.name,
                "target": self.target,
                "graph_id": self.graph_id,
                "pass_name": self.pass_name,
                "action": action_string,
                "from_node": [node.to_dict() for node in self.from_node],
            }

        assert self._dict is not None
        return self._dict

    def __eq__(self, other: object):
        if not isinstance(other, NodeSource):
            return False
        return self.to_dict() == other.to_dict()

    def __hash__(self):
        # Create a hash based on the dictionary representation
        # We need to convert the dict to a hashable form
        def _make_hashable(obj):
            if isinstance(obj, dict):
                return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
            elif isinstance(obj, list):
                return tuple(_make_hashable(item) for item in obj)
            else:
                return obj

        return hash(_make_hashable(self.to_dict()))

    @classmethod
    def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]:
        """
        Recursively deserialize from_node metadata from dictionary data.
        It is used to deserialize the from_node field from serialized metadata.
        Please use constructor NodeSource(node, ...) to create a NodeSource object.

        Returns None when ``d`` is None. Unknown action names are skipped.
        """
        if d is None:
            return None

        assert isinstance(d, dict), f"Expected a dict, got {type(d)}"

        # Create the object without going through the constructor (which needs
        # a live Node); use `cls` so subclasses deserialize to themselves.
        node_source = cls.__new__(cls)

        # Reset the cached properties
        node_source._action_string = None
        node_source._dict = None

        # Set the basic attributes
        node_source.pass_name = d.get("pass_name", "")

        # Invert _get_action_string: "create+replace" -> [CREATE, REPLACE].
        # Looking the name up in the enum keeps this in sync with any member
        # added to NodeSourceAction; unrecognized names are ignored.
        action_str = d.get("action", "")
        actions = []
        if action_str:
            for action_name in action_str.split("+"):
                member = NodeSourceAction.__members__.get(action_name.upper())
                if member is not None:
                    actions.append(member)
        node_source.action = actions

        # Presence of all three keys was checked, so index directly.
        if "name" in d and "target" in d and "graph_id" in d:
            node_source.node_info = cls.NodeInfo(
                d["name"], d["target"], d["graph_id"]
            )
        else:
            node_source.node_info = None

        # Recursively deserialize nested from_node, dropping None entries.
        nested = d.get("from_node")
        if nested is not None:
            node_source.from_node = [
                result for fn in nested if (result := cls._from_dict(fn)) is not None
            ]
        else:
            node_source.from_node = []
        return node_source
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
@contextmanager
def preserve_node_meta(enable=True):
    """
    Set the node-meta preservation flag to ``enable`` for the duration of the
    ``with`` block.

    On exit, even on error, the flag is restored to its previous value and
    ``current_meta`` is rebound to a copy of what it held on entry.
    """
    global should_preserve_node_meta, current_meta

    prior_flag = should_preserve_node_meta
    # A shallow copy is only sufficient as long as values stored in
    # current_meta are never mutated in place.
    prior_meta = current_meta.copy()
    try:
        should_preserve_node_meta = enable
        yield
    finally:
        current_meta = prior_meta
        should_preserve_node_meta = prior_flag
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def set_stack_trace(stack: list[str]):
    """
    Store ``stack`` (concatenated into one string) as
    ``current_meta["stack_trace"]``.

    Does nothing when preservation is disabled or ``stack`` is empty.
    """
    global current_meta

    if not should_preserve_node_meta or not stack:
        return
    current_meta["stack_trace"] = "".join(stack)
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
@contextmanager
def annotate(annotation_dict: dict):
    """
    Temporarily adds custom annotations to the current tracing context.
    The fx_node produced from this tracing context will have the
    custom annotations in node.meta["custom"] field.

    This context manager allows you to insert arbitrary metadata into the PT2
    tracing system by updating the global `current_meta["custom"]` dictionary.
    The annotations are automatically reverted after the context exits.

    Annotations from an enclosing ``annotate`` (or any existing
    ``current_meta["custom"]``) are kept; on key conflicts the innermost value
    wins. The previously active ``custom`` dict is never modified in place.

    This is intended for advanced users who need to attach additional metadata to the fx nodes
    (e.g., for debugging, analysis, or external tooling) during export tracing.

    Note:
        This API is **not backward compatible** and may evolve in future releases.

    Note:
        This API is not compatible with fx.symbolic_trace or jit.trace. It's intended
        to be used with PT2 family of tracers, e.g. torch.export and dynamo.

    Args:
        annotation_dict (dict): A dictionary of custom key-value pairs to inject
            into the FX trace metadata.

    Example:
        >>> with annotate({"source": "custom_pass", "tag": 42}):
        ...     # compute here
        # After exiting the context, custom annotations are removed.
    """

    global current_meta

    has_custom = "custom" in current_meta
    old_custom = current_meta.get("custom")

    try:
        # Build a fresh dict instead of updating the existing one in place.
        # The existing dict can be shared with other metadata dicts:
        # set_current_meta installs node.meta.copy() and preserve_node_meta
        # saves current_meta.copy() (both shallow), so an in-place update
        # would leak these annotations into node.meta or into the restored
        # metadata.
        merged = copy.copy(old_custom) if has_custom else {}
        merged.update(annotation_dict)
        current_meta["custom"] = merged
        yield
    finally:
        if has_custom:
            # Restore the exact dict object that was active on entry.
            current_meta["custom"] = old_custom
        else:
            # pop instead of del so a missing key here cannot raise and mask
            # an exception propagating out of the body.
            current_meta.pop("custom", None)
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
    """
    Push ``seq_nr`` onto ``current_meta["grad_fn_seq_nr"]`` and increment
    ``current_meta["in_grad_fn"]``.

    Does nothing unless node-meta preservation is enabled. Undo with
    reset_grad_fn_seq_nr().
    """
    global current_meta

    if not should_preserve_node_meta:
        return

    # The seq_nr is captured by eager mode in the grad_fn during forward.
    # A new list is assigned rather than appending to the existing one.
    existing = current_meta.get("grad_fn_seq_nr", [])
    current_meta["grad_fn_seq_nr"] = existing + [seq_nr]
    depth = current_meta.get("in_grad_fn", 0)
    current_meta["in_grad_fn"] = depth + 1
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def reset_grad_fn_seq_nr():
    """
    Undo one set_grad_fn_seq_nr(): pop the last seq_nr and decrement the
    ``in_grad_fn`` depth, removing both keys when the depth reaches zero.

    Does nothing unless node-meta preservation is enabled.
    """
    # NB: reset state properly, this would be helpful towards supporting
    # reentrant autograd if we actually wanted to do that.
    global current_meta

    if not should_preserve_node_meta:
        return

    depth = current_meta.get("in_grad_fn", 0)
    assert depth > 0
    if depth == 1:
        # Leaving the outermost grad_fn: drop the bookkeeping entirely.
        del current_meta["in_grad_fn"]
        del current_meta["grad_fn_seq_nr"]
        return

    current_meta["in_grad_fn"] = depth - 1
    # Slice to a new list rather than popping the shared one in place.
    current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def format_stack() -> list[str]:
    """
    Return the stack trace to record on a node.

    With preservation enabled this is ``[current_meta["stack_trace"]]`` (or
    ``[""]`` if unset); otherwise it is the live Python stack, formatted,
    without this function's own frame.
    """
    if not should_preserve_node_meta:
        # fallback to traceback.format_stack(); [:-1] drops this frame.
        frames = traceback.extract_stack()[:-1]
        return traceback.format_list(frames)
    return [current_meta.get("stack_trace", "")]
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def has_preserved_node_meta() -> bool:
    """Return the current node-meta preservation flag."""
    flag = should_preserve_node_meta
    return flag
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
@contextmanager
def set_current_meta(node, pass_name=""):
    """
    Within the ``with`` block, make ``current_meta`` a shallow copy of
    ``node.meta`` whose ``from_node`` is a single NodeSource pointing at
    ``node`` (action CREATE, tagged with ``pass_name``).

    The previous ``current_meta`` object is restored on exit. If preservation
    is disabled or ``node.meta`` is empty, this is a no-op.
    """
    global current_meta

    if not (should_preserve_node_meta and node.meta):
        yield
        return

    previous_meta = current_meta
    try:
        current_meta = node.meta.copy()

        # Overwrite (not append to) "from_node": current_meta will be assigned
        # to the new node, and NodeSource(node, ...) already carries the
        # information from node.meta["from_node"].
        source = NodeSource(node, pass_name, NodeSourceAction.CREATE)
        current_meta["from_node"] = [source]
        yield
    finally:
        current_meta = previous_meta
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def get_current_meta() -> dict[str, Any]:
    """
    Return the module-level ``current_meta`` dict itself (not a copy), so
    mutations by the caller are visible to this module.
    """
    meta = current_meta
    return meta
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
    """
    Given an fx.Graph, return a json that contains the provenance information of each node.

    The result maps the name of every ``call_function`` node to the list of
    ``to_dict()`` forms of its ``meta["from_node"]`` entries (``[]`` if absent).
    On any error a signpost event is emitted and ``{}`` is returned.
    """
    try:
        return {
            node.name: [src.to_dict() for src in node.meta.get("from_node", [])]
            for node in graph.nodes
            if node.op == "call_function"
        }
    except Exception as e:
        # Since this is just debugging, it should never interfere with regular
        # program execution, so we use this try-except to guard against any error
        details = {
            "function": "get_graph_provenance_json",
            "error_msg": str(e),
            "stack_trace": traceback.format_exc(),
        }
        signpost_event("inductor", "provenance_tracking_error", details)
        return {}
|