[fx] Filter stacktrace (#151029)

Filtering out the stacktrace so that the stacktrace on nodes when using fx.Tracer looks nicer. I just copied the filtering we have in [proxy_tensor.py](6720d23969/torch/fx/experimental/proxy_tensor.py (L1903-L1931)).

Previously the stacktrace looked like:
```
File "/data/users/angelayi/pytorch/moo.py", line 3964, in <module>
    run_tests()
  File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 1342, in run_tests
    unittest.main(argv=argv)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/runner.py", line 184, in run
    test(result)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 650, in __call__
    return self.run(*args, **kwds)
  File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 3324, in run
    self._run_custom(
  File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 3296, in _run_custom
    super_run(result=result)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 3156, in wrapper
    method(*args, **kwargs)
  File "/data/users/angelayi/pytorch/moo.py", line 1495, in test_stack_trace
    gm = torch.fx.GraphModule(m, tracer.trace(m))
  File "/data/users/angelayi/pytorch/torch/fx/_symbolic_trace.py", line 837, in trace
    (self.create_arg(fn(*args)),),
  File "/data/users/angelayi/pytorch/moo.py", line 1485, in forward
    x = x * 2
  File "/data/users/angelayi/pytorch/torch/fx/proxy.py", line 716, in impl
    return tracer.create_proxy("call_function", target, args, kwargs)
  File "/data/users/angelayi/pytorch/torch/fx/proxy.py", line 248, in create_proxy
    proxy.node.stack_trace = "".join(CapturedTraceback.extract().format())
```
Now it looks like:
```
File "/data/users/angelayi/pytorch/moo.py", line 1485, in forward
    x = x * 2
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151029
Approved by: https://github.com/jfix71, https://github.com/zou3519, https://github.com/jingsh
This commit is contained in:
angelayi 2025-04-22 22:50:32 +00:00 committed by PyTorch MergeBot
parent a7ccd96bbf
commit f4ac9a160d
4 changed files with 48 additions and 15 deletions

View File

@ -6010,7 +6010,7 @@ class TestAOTModuleSimplified(AOTTestCase):
mod = torch.fx.GraphModule(tracer.root, graph)
for node in mod.graph.nodes:
if node.op == "output":
if node.op != "call_function":
continue
self.assertTrue(node.stack_trace is not None)
assert "test_aotdispatch.py" in node.stack_trace
@ -6049,7 +6049,7 @@ class TestAOTModuleSimplified(AOTTestCase):
mod = torch.fx.GraphModule(tracer.root, graph)
for node in mod.graph.nodes:
if node.op == "output":
if node.op != "call_function":
continue
self.assertTrue(node.stack_trace is not None)
assert "test_aotdispatch.py" in node.stack_trace

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: fx"]
# ruff: noqa: F841
# flake8: noqa: E221
import builtins
import collections
@ -708,26 +709,33 @@ class TestFX(JitTestCase):
seen_names.add(node.name)
def test_stack_traces(self):
def foo(a, b):
return a * b
class M(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, a, b):
return a + b
c = a + b
c = foo(a, c)
return c
tracer = torch.fx.Tracer()
tracer.record_stack_traces = True
graph = tracer.trace(M())
# saving the original list because we will insert new nodes as a part of a test
orig_graph_nodes = list(graph.nodes)
for node in orig_graph_nodes:
if node.op == "output":
continue
self.assertTrue(node.stack_trace is not None)
assert "test_fx.py" in node.stack_trace
# verify that copying the node does not lose the stack trace
new_node = graph.node_copy(node)
self.assertTrue(new_node.stack_trace is not None)
assert "test_fx.py" in new_node.stack_trace
stack_traces = "\n".join([node.meta.get("stack_trace", "") for node in graph.nodes])
FileCheck().check_count(
"c = a + b", 1, exactly=True
).run(stack_traces.strip())
FileCheck().check_count(
"c = foo(a, c)", 1, exactly=True
).run(stack_traces.strip())
FileCheck().check_count(
"return a * b", 1, exactly=True
).run(stack_traces.strip())
def test_stack_traces_with_transformer(self):
class M(torch.nn.Module):

View File

@ -303,6 +303,8 @@ def uninteresting_files() -> set[str]:
torch.fx.experimental.recording,
torch.fx.experimental.sym_node,
torch.fx.interpreter,
torch.fx.proxy,
torch.fx._symbolic_trace,
torch,
torch._compile,
torch._dynamo.eval_frame,

View File

@ -8,6 +8,7 @@ import inspect
import logging
import operator
import sys
import traceback
from collections import OrderedDict
from collections.abc import Iterator
from dataclasses import fields, is_dataclass
@ -245,7 +246,29 @@ class TracerBase:
proxy = proxy_factory_fn(node)
if self.record_stack_traces and not proxy.node.stack_trace:
proxy.node.stack_trace = "".join(CapturedTraceback.extract().format())
from torch.fx.experimental.symbolic_shapes import uninteresting_files
user_frame_summary = CapturedTraceback.extract().summary()
if user_frame_summary:
first_forward = -1
for i, frame in enumerate(user_frame_summary):
if frame.name == "forward":
user_frame_summary = user_frame_summary[i:]
first_forward = i
break
# Not having a "forward" call in the stacktrace implies the
# stacktrace will probably be irrelevant
if first_forward == -1:
user_frame_summary = []
stack_trace = [
frame
for frame in user_frame_summary
if frame.filename not in uninteresting_files()
]
stack_trace = traceback.StackSummary.from_list(stack_trace)
proxy.node.stack_trace = "".join(stack_trace.format()).strip()
return proxy