Add kernel stack traces to tlparse dump (#160608) (#160779)

Summary:

as title

This is requested by the zoomer team so they can add stack trace information to profiler results.

Test Plan:
```
buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing -- -r  stack_traces
```

Rollback Plan:

Differential Revision: D80050233

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160779
Approved by: https://github.com/angelayi
This commit is contained in:
Shangdi Yu 2025-08-16 03:12:38 +00:00 committed by PyTorch MergeBot
parent b7ca502f29
commit b74c7cd335
3 changed files with 135 additions and 0 deletions

View File

@ -1,5 +1,7 @@
# Owner(s): ["module: inductor"]
import contextlib
import io
import json
import logging
import re
@ -28,6 +30,9 @@ except ImportError:
from test_aot_inductor_utils import AOTIRunnerUtil
trace_log = logging.getLogger("torch.__trace")
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
@ -483,5 +488,105 @@ class TestProvenanceTracingNodeMeta(TestCase):
self.assertEqual(mm_node.meta["stack_trace"], stack_trace)
class ProvenanceArtifactFilter(logging.Filter):
    """Pass only structured-trace records carrying the kernel stack trace artifact.

    Records whose metadata has no "artifact" entry are rejected outright;
    otherwise the artifact name must match the provenance-tracking payload.
    """

    def filter(self, record):
        # Guard clause: non-artifact records are never provenance payloads.
        if "artifact" not in record.metadata:
            return False
        artifact_name = record.metadata["artifact"]["name"]
        return artifact_name == "inductor_provenance_tracking_kernel_stack_traces"
class StructuredTracePayloadFormatter(logging.Formatter):
    """Render a structured-trace record as its raw payload, trimmed of surrounding whitespace."""

    def format(self, record):
        payload = record.payload
        return payload.strip()
class TestProvenanceTracingStackTraces(TestCase):
    """Checks that inductor emits the kernel-name -> stack-traces mapping to tlparse.

    Fix: removed leftover debug ``print`` statements from the assertion loop.
    """

    @contextlib.contextmanager
    def _setup_provenance_capture(self):
        """Attach a handler to ``trace_log`` that captures only the
        'inductor_provenance_tracking_kernel_stack_traces' artifact payload.

        Yields the ``io.StringIO`` buffer the payload is written into.
        """
        payload_buffer = io.StringIO()
        payload_handler = logging.StreamHandler(payload_buffer)
        payload_handler.setLevel(logging.DEBUG)
        payload_handler.setFormatter(StructuredTracePayloadFormatter())
        payload_handler.addFilter(ProvenanceArtifactFilter())
        trace_log.addHandler(payload_handler)
        try:
            yield payload_buffer
        finally:
            # Always detach so other tests do not observe this handler.
            trace_log.removeHandler(payload_handler)

    def extract_code_line(self, s):
        """Return the last non-empty line of a stack trace string.

        Stack trace entries end with a trailing newline, so after splitting
        the source line of interest is the second-to-last element.
        """
        return s.split("\n")[-2].strip()

    @torch._inductor.config.patch(
        {"fx_graph_cache": False, "trace.provenance_tracking_level": 2}
    )
    @requires_cuda_and_triton
    def test_tlparse_kernel_stack_traces(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 16)
                self.relu = torch.nn.ReLU()
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, x, a, b, c):
                x = self.fc1(x)
                x = self.relu(x)
                x = self.sigmoid(x)
                d = a * 3.14
                y = torch.addmm(c, d, b)
                z = torch.nn.functional.gelu(y)
                return x, z

        device = "cuda"
        model = Model().to(device)
        x = torch.randn(8, 10).to(device)
        a = torch.randn(10, 20).to(device)
        b = torch.randn(20, 30).to(device)
        c = torch.randn(10, 30).to(device)
        example_inputs = (x, a, b, c)

        # Expected mapping from generated kernel name to the model source
        # lines that produced it (order-insensitive; compared sorted below).
        expected = {
            "triton_poi_fused_addmm_relu_sigmoid_threshold_backward_0": [
                "x = self.sigmoid(x)",
                "x = self.fc1(x)",
                "x = self.relu(x)",
            ],
            "triton_poi_fused_mul_1": [
                "d = a * 3.14",
            ],
            "triton_poi_fused_addmm_gelu_2": [
                "z = torch.nn.functional.gelu(y)",
                "y = torch.addmm(c, d, b)",
            ],
            "extern_kernels.mm": [
                "y = torch.addmm(c, d, b)",
            ],
        }

        with self._setup_provenance_capture() as payload_buffer:
            compiled = torch.compile(model)
            compiled(*example_inputs)
        payload_content = payload_buffer.getvalue().strip()
        # NOTE(review): if no payload was emitted the assertions below are
        # skipped and the test silently passes — consider asserting that
        # payload_content is non-empty.
        if payload_content:
            data = json.loads(payload_content)
            self.assertEqual(set(data.keys()), set(expected.keys()))
            for key, expected_lines in expected.items():
                actual_lines = [self.extract_code_line(s) for s in data[key]]
                self.assertEqual(
                    sorted(actual_lines),
                    sorted(expected_lines),
                    f"Mismatch for key: {key}",
                )
# Standard test-file entry point; run_tests is presumably imported from the
# torch test harness earlier in this file — confirm against the full file.
if __name__ == "__main__":
    run_tests()

View File

@ -1070,6 +1070,16 @@ def _compile_fx_inner(
torch._inductor.debug.dump_inductor_provenance_info()
),
)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "inductor_provenance_tracking_kernel_stack_traces",
"encoding": "json",
},
payload_fn=lambda: json.dumps(
torch._inductor.debug._inductor_kernel_stack_trace
),
)
# This message is for printing overview information of inductor mm counts, shapes,etc after lowering
if log.isEnabledFor(logging.INFO):

View File

@ -319,6 +319,7 @@ _inductor_post_to_pre_grad_nodes: dict[str, dict[str, list[str]]] = {}
# Maps a generated kernel name to info about the post-grad FX nodes it
# originated from (populated via origin names in provenance tracing).
_inductor_triton_kernel_to_post_grad_node_info: dict[str, Any] = {}
# Id of the current pre-grad FX graph; None until one is recorded.
_pre_grad_graph_id: Optional[int] = None
# Maps a pre-grad FX node name to its recorded stack trace string.
_inductor_pre_grad_node_stack_trace: dict[str, str] = {}
# Maps a generated kernel name to the stack traces of the scheduled nodes
# fused into it; dumped to tlparse as the
# "inductor_provenance_tracking_kernel_stack_traces" artifact.
_inductor_kernel_stack_trace: dict[str, list[str]] = {}
@contextlib.contextmanager
@ -328,6 +329,8 @@ def reset_provenance_globals() -> Iterator[None]:
global _pre_grad_graph_id
global _inductor_post_to_pre_grad_nodes
global _inductor_triton_kernel_to_post_grad_node_info
global _inductor_pre_grad_node_stack_trace
global _inductor_kernel_stack_trace
# Store original values
original_pre_grad_graph_id = _pre_grad_graph_id
@ -335,11 +338,17 @@ def reset_provenance_globals() -> Iterator[None]:
original_triton_kernel_to_post_grad_node_info = (
_inductor_triton_kernel_to_post_grad_node_info.copy()
)
original_inductor_pre_grad_node_stack_trace = (
_inductor_pre_grad_node_stack_trace.copy()
)
original_inductor_kernel_stack_trace = _inductor_kernel_stack_trace.copy()
# Reset to default values
_pre_grad_graph_id = -1
_inductor_post_to_pre_grad_nodes = {}
_inductor_triton_kernel_to_post_grad_node_info = {}
_inductor_pre_grad_node_stack_trace = {}
_inductor_kernel_stack_trace = {}
try:
yield
@ -350,6 +359,10 @@ def reset_provenance_globals() -> Iterator[None]:
_inductor_triton_kernel_to_post_grad_node_info = (
original_triton_kernel_to_post_grad_node_info
)
_inductor_kernel_stack_trace = original_inductor_kernel_stack_trace
_inductor_pre_grad_node_stack_trace = (
original_inductor_pre_grad_node_stack_trace
)
class DebugContext:
@ -942,6 +955,7 @@ def set_kernel_post_grad_provenance_tracing(
from .codegen.simd_kernel_features import DisableReduction, EnableReduction
global _inductor_triton_kernel_to_post_grad_node_info
global _inductor_kernel_stack_trace
if is_extern:
assert isinstance(node_schedule, ExternKernelOut)
curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
@ -960,8 +974,12 @@ def set_kernel_post_grad_provenance_tracing(
for origin in node_schedule.origins
if origin.name not in curr_node_info
)
_inductor_kernel_stack_trace[kernel_name] = list(
node_schedule.get_stack_traces()
)
else:
assert isinstance(node_schedule, list)
stack_traces: OrderedSet[str] = OrderedSet()
for snode in node_schedule:
if snode not in (EnableReduction, DisableReduction):
if snode.node is not None:
@ -970,11 +988,13 @@ def set_kernel_post_grad_provenance_tracing(
kernel_name, []
)
)
stack_traces.update(snode.node.get_stack_traces())
curr_node_info.extend(
origin.name
for origin in snode.node.origins
if origin.name not in curr_node_info
)
_inductor_kernel_stack_trace[kernel_name] = list(stack_traces)
except Exception as e:
# Since this is just debugging, it should never interfere with regular
# program execution, so we use this try-except to guard against any error