Add stack trace to Inductor IR nodes if inductor.config.trace.provenance_tracking=True (#158576)

Summary:
- Split `create_node_mapping` into `create_mapping_pre_post_grad_nodes` and `create_node_mapping_kernel_to_post_grad`
- Store a mapping from pre_grad graph node names to stack traces in `_inductor_pre_grad_node_stack_trace`
- Add a `stack_traces` member to `ir.IRNode` and include it in the string representation of `IRNode`
- When we create an IR node, if `inductor.config.trace.provenance_tracking=True`, we populate `stack_traces` from `origins`. The nodes in `origins` are post_grad graph nodes. If a node has `node.stack_trace`, we store the stack trace directly. This is particularly important for backward graph nodes because they don't have a mapping to pre_grad graph nodes. If a node doesn't have `.stack_trace` (such as `linear` -> `addmm` nodes), we use the stack traces of the pre_grad graph nodes that it maps to.
  - A post_grad graph node might not have a stack trace if it corresponds to multiple pre_grad graph nodes, e.g. [GroupLinearFusion](a00442421a/torch/_inductor/fx_passes/group_batch_fusion.py (L299))

Example:

```
scheduling ExternKernelOut(
  python_kernel_name='extern_kernels.mm',
  name=buf0,
  layout=FixedLayout('cuda:0', torch.float32, size=[8, 16], stride=[16, 1]),
  inputs=[InputBuffer(name='arg2_1', layout=FixedLayout('cuda:0', torch.float32, size=[8, 10], stride=[10, 1])), ReinterpretView(
    StorageBox(
      ConstantBuffer(name='fc1_weight', layout=FixedLayout('cuda:0', torch.float32, size=[16, 10], stride=[10, 1]))
    ),
    FixedLayout('cuda:0', torch.float32, size=[10, 16], stride=[1, 10]),
    origins=OrderedSet([mm_default_1]),
    stack_traces = {,
    File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/scripts/shangdiy/aot.py", line 29, in forward,
        x = self.fc1(x),
      File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/torch/nn/modules/linear.py", line 125, in forward,
        return F.linear(input, self.weight, self.bias),
    }
  )],
  constant_args=(),
  kwargs={},
  output_view=None,
  python_kernel_name=extern_kernels.mm,
  cpp_kernel_name=at::mm_out,
  ordered_kwargs_for_cpp_kernel=(),
  op_overload=None,
  arg_properties=[{}, {}],
  allarg_properties={},
  kwarg_properties=None,
  unbacked_bindings={},
  mutation_outputs=[],
  origin_node=mm_default_1,
  origins=OrderedSet([mm_default_1]),
  stack_traces = {,
  File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/scripts/shangdiy/aot.py", line 29, in forward,
      x = self.fc1(x),
    File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/torch/nn/modules/linear.py", line 125, in forward,
      return F.linear(input, self.weight, self.bias),
  }
)
```

Test Plan:
```
buck2 run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing
```

Rollback Plan:

Differential Revision: D78365534

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158576
Approved by: https://github.com/angelayi
This commit is contained in:
Shangdi Yu 2025-07-18 04:05:17 +00:00 committed by PyTorch MergeBot
parent 86dbc0ef67
commit 1e86fa2e5b
4 changed files with 143 additions and 46 deletions

View File

@ -10,7 +10,10 @@ from pathlib import Path
import torch import torch
from torch._inductor import config from torch._inductor import config
from torch._inductor.debug import create_node_mapping from torch._inductor.debug import (
create_mapping_pre_post_grad_nodes,
create_node_mapping_kernel_to_post_grad,
)
from torch._inductor.test_case import run_tests, TestCase from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.inductor_utils import HAS_GPU from torch.testing._internal.inductor_utils import HAS_GPU
from torch.testing._internal.triton_utils import requires_cuda from torch.testing._internal.triton_utils import requires_cuda
@ -386,11 +389,17 @@ class TestProvenanceTracingNodeMapping(TestCase):
"triton_poi_fused_addmm_relu_sigmoid_0": ["relu", "add_tensor"] "triton_poi_fused_addmm_relu_sigmoid_0": ["relu", "add_tensor"]
} }
result = create_node_mapping( result = create_mapping_pre_post_grad_nodes(
pre_grad_graph_id, pre_grad_graph_id,
post_to_pre_grad_nodes_json, post_to_pre_grad_nodes_json,
triton_kernel_to_post_grad_json,
) )
result = {
**result,
**create_node_mapping_kernel_to_post_grad(
triton_kernel_to_post_grad_json,
),
}
self.assertEqual( self.assertEqual(
result, result,
{ {

View File

@ -1033,17 +1033,13 @@ def _compile_fx_inner(
provenance_info = torch._inductor.debug.dump_inductor_provenance_info() provenance_info = torch._inductor.debug.dump_inductor_provenance_info()
# provenance_info might be None if trace.provenance_tracking is not set # provenance_info might be None if trace.provenance_tracking is not set
if provenance_info: if provenance_info:
(
_,
node_mappings,
) = provenance_info
trace_structured( trace_structured(
"artifact", "artifact",
metadata_fn=lambda: { metadata_fn=lambda: {
"name": "inductor_provenance_tracking_node_mappings", "name": "inductor_provenance_tracking_node_mappings",
"encoding": "json", "encoding": "json",
}, },
payload_fn=lambda: json.dumps(node_mappings), payload_fn=lambda: json.dumps(provenance_info),
) )
# This message is for printing overview information of inductor mm counts, shapes,etc after lowering # This message is for printing overview information of inductor mm counts, shapes,etc after lowering
@ -1299,8 +1295,13 @@ class _InProcessFxCompile(FxCompile):
}, },
payload_fn=lambda: json.dumps(provenance_tracking_json), payload_fn=lambda: json.dumps(provenance_tracking_json),
) )
from torch._inductor.debug import create_mapping_pre_post_grad_nodes
torch._inductor.debug._inductor_post_to_pre_grad_nodes = ( torch._inductor.debug._inductor_post_to_pre_grad_nodes = (
provenance_tracking_json create_mapping_pre_post_grad_nodes(
torch._inductor.debug._pre_grad_graph_id,
provenance_tracking_json,
)
) )
metrics_context = get_metrics_context() metrics_context = get_metrics_context()
@ -2174,6 +2175,13 @@ def compile_fx(
) )
torch._inductor.debug._pre_grad_graph_id = id(model_.graph) torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
if config.trace.provenance_tracking:
for node in model_.graph.nodes:
if node.stack_trace:
torch._inductor.debug._inductor_pre_grad_node_stack_trace[
node.name
] = node.stack_trace
model_ = _recursive_pre_grad_passes(model_, example_inputs_) model_ = _recursive_pre_grad_passes(model_, example_inputs_)
trace_structured( trace_structured(
"artifact", "artifact",

View File

@ -316,6 +316,7 @@ def enable_aot_logging() -> Iterator[None]:
_inductor_post_to_pre_grad_nodes: dict[str, Any] = {} _inductor_post_to_pre_grad_nodes: dict[str, Any] = {}
_inductor_triton_kernel_to_post_grad_node_info: dict[str, Any] = {} _inductor_triton_kernel_to_post_grad_node_info: dict[str, Any] = {}
_pre_grad_graph_id: Optional[int] = None _pre_grad_graph_id: Optional[int] = None
_inductor_pre_grad_node_stack_trace: dict[str, str] = {}
@contextlib.contextmanager @contextlib.contextmanager
@ -701,23 +702,18 @@ class TensorMetadataHolder:
save_args_cnt = itertools.count() save_args_cnt = itertools.count()
def create_node_mapping( def create_mapping_pre_post_grad_nodes(
pre_grad_graph_id: int, pre_grad_graph_id: Optional[int],
post_to_pre_grad_nodes_json: dict[str, Any], post_to_pre_grad_nodes_json: dict[str, Any],
triton_kernel_to_post_grad_json: dict[str, Any],
) -> dict[str, dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
"""Create bidirectional mappings between:
- pre_grad graph nodes and post_grad graph code nodes, and vice versa
- triton kernel name and post_grad graph code nodes, and vice versa
""" """
Create bidirectional mappings between pre_grad graph nodes
and post_grad graph code nodes, and vice versa.
"""
# return a dummy dict if there's any error # return a dummy dict if there's any error
empty_return: dict[str, dict[str, Any]] = { empty_return: dict[str, dict[str, Any]] = {
"preToPost": {}, "preToPost": {},
"postToPre": {}, "postToPre": {},
"cppCodeToPost": {},
"postToCppCode": {},
} }
log.info("Creating node mappings for provenance tracking") log.info("Creating node mappings for provenance tracking")
@ -726,12 +722,6 @@ def create_node_mapping(
log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict") log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict")
return empty_return return empty_return
if not isinstance(triton_kernel_to_post_grad_json, dict):
log.error(
"Provenance tacking error: triton_kernel_to_post_grad_json is not a dict"
)
return empty_return
if not isinstance(pre_grad_graph_id, int): if not isinstance(pre_grad_graph_id, int):
log.error("Provenance tacking error: pre_grad_graph_id is not an int") log.error("Provenance tacking error: pre_grad_graph_id is not an int")
return empty_return return empty_return
@ -739,17 +729,7 @@ def create_node_mapping(
pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet) pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet)
post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet) post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet)
post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet)
try: try:
for outer_key, node_array in triton_kernel_to_post_grad_json.items():
if not isinstance(node_array, list):
log.error(
"Provenance tacking error: triton_kernel_to_post_grad_json value is not a list"
)
return empty_return
for curr_node in node_array:
post_to_cpp_code[curr_node].add(outer_key)
def check_format(node: dict[str, Any]) -> bool: def check_format(node: dict[str, Any]) -> bool:
if not isinstance(node, dict): if not isinstance(node, dict):
@ -799,10 +779,61 @@ def create_node_mapping(
# convert to list because set is not JSON serializable # convert to list because set is not JSON serializable
convert_sets_to_lists(pre_to_post) convert_sets_to_lists(pre_to_post)
convert_sets_to_lists(post_to_pre) convert_sets_to_lists(post_to_pre)
convert_sets_to_lists(post_to_cpp_code)
return { return {
"preToPost": pre_to_post, "preToPost": pre_to_post,
"postToPre": post_to_pre, "postToPre": post_to_pre,
}
except Exception as e:
# Since this is just logging code, it should never interfere with regular
# program execution, so we use this try-except to guard against any error
log.error("Unexpected error in create_node_mapping: %s", e)
log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json)
log.error("pre_grad_graph_id: %s", pre_grad_graph_id)
log.error(traceback.format_exc())
return empty_return
def create_node_mapping_kernel_to_post_grad(
triton_kernel_to_post_grad_json: dict[str, Any],
) -> dict[str, dict[str, Any]]:
"""Create bidirectional mappings between triton kernel name and post_grad
graph code nodes, and vice versa.
"""
# return a dummy dict if there's any error
empty_return: dict[str, dict[str, Any]] = {
"cppCodeToPost": {},
"postToCppCode": {},
}
log.info("Creating node mappings for provenance tracking")
if not isinstance(triton_kernel_to_post_grad_json, dict):
log.error(
"Provenance tacking error: triton_kernel_to_post_grad_json is not a dict"
)
return empty_return
post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet)
try:
for outer_key, node_array in triton_kernel_to_post_grad_json.items():
if not isinstance(node_array, list):
log.error(
"Provenance tacking error: triton_kernel_to_post_grad_json value is not a list"
)
return empty_return
for curr_node in node_array:
post_to_cpp_code[curr_node].add(outer_key)
def convert_sets_to_lists(d: dict[str, Any]) -> None:
for key in d:
d[key] = list(d[key])
d = dict(d)
# convert to list because set is not JSON serializable
convert_sets_to_lists(post_to_cpp_code)
return {
"cppCodeToPost": triton_kernel_to_post_grad_json, "cppCodeToPost": triton_kernel_to_post_grad_json,
"postToCppCode": post_to_cpp_code, "postToCppCode": post_to_cpp_code,
} }
@ -810,37 +841,38 @@ def create_node_mapping(
# Since this is just logging code, it should never interfere with regular # Since this is just logging code, it should never interfere with regular
# program execution, so we use this try-except to guard against any error # program execution, so we use this try-except to guard against any error
log.error("Unexpected error in create_node_mapping: %s", e) log.error("Unexpected error in create_node_mapping: %s", e)
log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json)
log.error( log.error(
"triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json "triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json
) )
log.error("pre_grad_graph_id: %s", pre_grad_graph_id)
log.error(traceback.format_exc()) log.error(traceback.format_exc())
return empty_return return empty_return
def dump_inductor_provenance_info( def dump_inductor_provenance_info(
filename: str = "inductor_generated_kernel_to_post_grad_nodes.json", filename: str = "inductor_generated_kernel_to_post_grad_nodes.json",
) -> tuple[dict[str, list[str]], dict[str, Any]]: ) -> dict[str, Any]:
global _pre_grad_graph_id global _pre_grad_graph_id
global _inductor_post_to_pre_grad_nodes global _inductor_post_to_pre_grad_nodes
global _inductor_triton_kernel_to_post_grad_node_info global _inductor_triton_kernel_to_post_grad_node_info
debug_info = _inductor_triton_kernel_to_post_grad_node_info.copy()
if config.trace.enabled: if config.trace.enabled:
with V.debug.fopen(filename, "w") as fd: with V.debug.fopen(filename, "w") as fd:
log.info("Writing provenance tracing debugging info to %s", fd.name) log.info("Writing provenance tracing debugging info to %s", fd.name)
json.dump(debug_info, fd) json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd)
node_mapping = {} node_mapping = {}
if _pre_grad_graph_id: if _pre_grad_graph_id:
node_mapping = create_node_mapping( node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
_pre_grad_graph_id, _inductor_post_to_pre_grad_nodes, debug_info _inductor_triton_kernel_to_post_grad_node_info
) )
node_mapping = {
**_inductor_post_to_pre_grad_nodes,
**node_mapping_kernel,
}
if config.trace.enabled: if config.trace.enabled:
with V.debug.fopen( with V.debug.fopen(
"inductor_provenance_tracking_node_mappings.json", "w" "inductor_provenance_tracking_node_mappings.json", "w"
) as fd: ) as fd:
json.dump(node_mapping, fd) json.dump(node_mapping, fd)
return debug_info, node_mapping return node_mapping
def set_kernel_post_grad_provenance_tracing( def set_kernel_post_grad_provenance_tracing(

View File

@ -541,12 +541,23 @@ def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]:
class IRNode: class IRNode:
"""Base class for all intermediate representation (IR) nodes in TorchInductor.
Note:
This is an abstract base class. Most methods raise NotImplementedError
and must be overridden by concrete subclasses.
"""
_current_origins: ClassVar[OrderedSet[Any]] = OrderedSet() _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet()
# NB: These are kinda weird, # NB: These are kinda weird,
origins: OrderedSet[Any] = dataclasses.field(init=False) origins: OrderedSet[Any] = dataclasses.field(init=False)
# traces back to where the IRNode is created in Inductor
traceback: Optional[list[str]] = dataclasses.field(init=False) traceback: Optional[list[str]] = dataclasses.field(init=False)
origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False) origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False)
# trace backs to user model code
# a single IRNode could correspond to multiple lines of code
stack_traces: dict[str, str] = dataclasses.field(init=False)
@staticmethod @staticmethod
@contextlib.contextmanager @contextlib.contextmanager
@ -578,12 +589,41 @@ class IRNode:
object.__setattr__(self, attr, value) object.__setattr__(self, attr, value)
def __post_init__(self) -> None: def __post_init__(self) -> None:
self._post_init_setattr("origins", OrderedSet(self._current_origins)) origins = OrderedSet(self._current_origins)
self._post_init_setattr("origins", origins)
self._post_init_setattr( self._post_init_setattr(
"traceback", traceback.format_stack() if config.debug_ir_traceback else None "traceback", traceback.format_stack() if config.debug_ir_traceback else None
) )
self._post_init_setattr("origin_node", None) self._post_init_setattr("origin_node", None)
# Group nodes by their stack traces to deduplicate
nodes_to_stack_trace = {}
if config.trace.provenance_tracking:
for node in origins:
if node.stack_trace:
# nodes in the backward graph don't have mapping to pre_grad_graph
nodes_to_stack_trace["post_grad+" + node.name] = node.stack_trace
else:
if (
"postToPre"
not in torch._inductor.debug._inductor_post_to_pre_grad_nodes
):
continue
node_names = torch._inductor.debug._inductor_post_to_pre_grad_nodes[
"postToPre"
].get(node.name, None)
if node_names:
for node_name in node_names:
stack_trace = torch._inductor.debug._inductor_pre_grad_node_stack_trace.get(
node_name, None
)
if stack_trace:
nodes_to_stack_trace["pre_grad+" + node_name] = (
stack_trace
)
self._post_init_setattr("stack_traces", nodes_to_stack_trace)
def get_read_names(self) -> OrderedSet[str]: def get_read_names(self) -> OrderedSet[str]:
return OrderedSet(dep.name for dep in self.get_reads()) return OrderedSet(dep.name for dep in self.get_reads())
@ -601,7 +641,15 @@ class IRNode:
if shorten and len(origins) > 64: if shorten and len(origins) > 64:
# this can get *very* long # this can get *very* long
origins = f"{origins[:61]}..." origins = f"{origins[:61]}..."
return [origins] if not self.stack_traces:
return [origins]
stack_trace_str = []
for stack_trace in self.stack_traces.values():
stack_trace_str.append("stack_traces = {{")
stack_trace_str += stack_trace.split("\n")
stack_trace_str.append("}")
return [origins] + stack_trace_str
def str_helper( def str_helper(
self, lines: Sequence[object], shorten: bool = True, multiline: bool = True self, lines: Sequence[object], shorten: bool = True, multiline: bool = True