mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[fx] Filter stacktrace (#151029)
Filtering out the stacktrace so that the stacktrace on nodes when using fx.Tracer looks nicer. I just copied the filtering we have in [proxy_tensor.py](6720d23969/torch/fx/experimental/proxy_tensor.py (L1903-L1931)).
Previously the stacktrace looked like:
```
File "/data/users/angelayi/pytorch/moo.py", line 3964, in <module>
run_tests()
File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 1342, in run_tests
unittest.main(argv=argv)
File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/main.py", line 101, in __init__
self.runTests()
File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/main.py", line 271, in runTests
self.result = testRunner.run(self.test)
File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/runner.py", line 184, in run
test(result)
File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 650, in __call__
return self.run(*args, **kwds)
File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 3324, in run
self._run_custom(
File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 3296, in _run_custom
super_run(result=result)
File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 591, in run
self._callTestMethod(testMethod)
File "/home/angelayi/.conda/envs/pytorch-3.10/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
method()
File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 3156, in wrapper
method(*args, **kwargs)
File "/data/users/angelayi/pytorch/moo.py", line 1495, in test_stack_trace
gm = torch.fx.GraphModule(m, tracer.trace(m))
File "/data/users/angelayi/pytorch/torch/fx/_symbolic_trace.py", line 837, in trace
(self.create_arg(fn(*args)),),
File "/data/users/angelayi/pytorch/moo.py", line 1485, in forward
x = x * 2
File "/data/users/angelayi/pytorch/torch/fx/proxy.py", line 716, in impl
return tracer.create_proxy("call_function", target, args, kwargs)
File "/data/users/angelayi/pytorch/torch/fx/proxy.py", line 248, in create_proxy
proxy.node.stack_trace = "".join(CapturedTraceback.extract().format())
```
Now it looks like:
```
File "/data/users/angelayi/pytorch/moo.py", line 1485, in forward
x = x * 2
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151029
Approved by: https://github.com/jfix71, https://github.com/zou3519, https://github.com/jingsh
This commit is contained in:
parent
a7ccd96bbf
commit
f4ac9a160d
|
|
@ -6010,7 +6010,7 @@ class TestAOTModuleSimplified(AOTTestCase):
|
|||
mod = torch.fx.GraphModule(tracer.root, graph)
|
||||
|
||||
for node in mod.graph.nodes:
|
||||
if node.op == "output":
|
||||
if node.op != "call_function":
|
||||
continue
|
||||
self.assertTrue(node.stack_trace is not None)
|
||||
assert "test_aotdispatch.py" in node.stack_trace
|
||||
|
|
@ -6049,7 +6049,7 @@ class TestAOTModuleSimplified(AOTTestCase):
|
|||
mod = torch.fx.GraphModule(tracer.root, graph)
|
||||
|
||||
for node in mod.graph.nodes:
|
||||
if node.op == "output":
|
||||
if node.op != "call_function":
|
||||
continue
|
||||
self.assertTrue(node.stack_trace is not None)
|
||||
assert "test_aotdispatch.py" in node.stack_trace
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
# Owner(s): ["module: fx"]
|
||||
# ruff: noqa: F841
|
||||
# flake8: noqa: E221
|
||||
|
||||
import builtins
|
||||
import collections
|
||||
|
|
@ -708,26 +709,33 @@ class TestFX(JitTestCase):
|
|||
seen_names.add(node.name)
|
||||
|
||||
def test_stack_traces(self):
|
||||
def foo(a, b):
|
||||
return a * b
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, a, b):
|
||||
return a + b
|
||||
c = a + b
|
||||
c = foo(a, c)
|
||||
return c
|
||||
|
||||
tracer = torch.fx.Tracer()
|
||||
tracer.record_stack_traces = True
|
||||
|
||||
graph = tracer.trace(M())
|
||||
# saving the original list because we will insert new nodes as a part of a test
|
||||
orig_graph_nodes = list(graph.nodes)
|
||||
for node in orig_graph_nodes:
|
||||
if node.op == "output":
|
||||
continue
|
||||
self.assertTrue(node.stack_trace is not None)
|
||||
assert "test_fx.py" in node.stack_trace
|
||||
|
||||
# verify that copying the node does not lose the stack trace
|
||||
new_node = graph.node_copy(node)
|
||||
self.assertTrue(new_node.stack_trace is not None)
|
||||
assert "test_fx.py" in new_node.stack_trace
|
||||
stack_traces = "\n".join([node.meta.get("stack_trace", "") for node in graph.nodes])
|
||||
FileCheck().check_count(
|
||||
"c = a + b", 1, exactly=True
|
||||
).run(stack_traces.strip())
|
||||
FileCheck().check_count(
|
||||
"c = foo(a, c)", 1, exactly=True
|
||||
).run(stack_traces.strip())
|
||||
FileCheck().check_count(
|
||||
"return a * b", 1, exactly=True
|
||||
).run(stack_traces.strip())
|
||||
|
||||
def test_stack_traces_with_transformer(self):
|
||||
class M(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -303,6 +303,8 @@ def uninteresting_files() -> set[str]:
|
|||
torch.fx.experimental.recording,
|
||||
torch.fx.experimental.sym_node,
|
||||
torch.fx.interpreter,
|
||||
torch.fx.proxy,
|
||||
torch.fx._symbolic_trace,
|
||||
torch,
|
||||
torch._compile,
|
||||
torch._dynamo.eval_frame,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import inspect
|
|||
import logging
|
||||
import operator
|
||||
import sys
|
||||
import traceback
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import fields, is_dataclass
|
||||
|
|
@ -245,7 +246,29 @@ class TracerBase:
|
|||
proxy = proxy_factory_fn(node)
|
||||
|
||||
if self.record_stack_traces and not proxy.node.stack_trace:
|
||||
proxy.node.stack_trace = "".join(CapturedTraceback.extract().format())
|
||||
from torch.fx.experimental.symbolic_shapes import uninteresting_files
|
||||
|
||||
user_frame_summary = CapturedTraceback.extract().summary()
|
||||
if user_frame_summary:
|
||||
first_forward = -1
|
||||
for i, frame in enumerate(user_frame_summary):
|
||||
if frame.name == "forward":
|
||||
user_frame_summary = user_frame_summary[i:]
|
||||
first_forward = i
|
||||
break
|
||||
|
||||
# Not having a "forward" call in the stacktrace implies the
|
||||
# stacktrace will probably be irrelevant
|
||||
if first_forward == -1:
|
||||
user_frame_summary = []
|
||||
|
||||
stack_trace = [
|
||||
frame
|
||||
for frame in user_frame_summary
|
||||
if frame.filename not in uninteresting_files()
|
||||
]
|
||||
stack_trace = traceback.StackSummary.from_list(stack_trace)
|
||||
proxy.node.stack_trace = "".join(stack_trace.format()).strip()
|
||||
|
||||
return proxy
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user