[export] Improve stacktrace filtering (#141285)

Differential Revision: [D66321127](https://our.internmc.facebook.com/intern/diff/D66321127)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141285
Approved by: https://github.com/yushangdi
ghstack dependencies: #141071, #141072
This commit is contained in:
angelayi 2024-11-21 21:24:13 -08:00 committed by PyTorch MergeBot
parent 53df1c11cd
commit 32583d915e

View File

@ -1,12 +1,12 @@
import inspect
import logging
import sys
import os
from enum import IntEnum
from functools import lru_cache
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch._logging._internal
import torch._logging.structured
from torch._export.passes.insert_custom_op_guards import insert_custom_op_guards
from torch.export import ExportedProgram
from torch.export._trace import _export
@ -26,31 +26,7 @@ class FailureType(IntEnum):
return self.name
@lru_cache
def uninteresting_files() -> Set[str]:
import torch._inductor.sizevars
import torch._subclasses.fake_tensor
import torch._subclasses.meta_utils
mods = [
sys.modules[__name__],
torch.fx.experimental.recording,
torch.fx.experimental.sym_node,
torch.fx.experimental.symbolic_shapes,
torch.fx.interpreter,
torch,
torch._inductor.sizevars,
torch._logging._internal,
torch._subclasses.meta_utils,
torch._subclasses.fake_tensor,
torch._subclasses.functional_tensor,
]
return {inspect.getfile(m) for m in mods}
def prettify_stack(
stack: List[Dict["str", "str"]], str_to_filename: Dict[str, str]
) -> str:
def prettify_stack(stack: List[Dict[str, str]], str_to_filename: Dict[str, str]) -> str:
res = ""
for frame in stack:
if frame["filename"] not in str_to_filename:
@ -68,7 +44,8 @@ def filter_stack(
s["filename"] = str(s["filename"])
if s["filename"] not in str_to_filename:
continue
if str_to_filename[s["filename"]] not in uninteresting_files():
torch_filepath = os.path.dirname(inspect.getfile(torch)) + os.path.sep
if torch_filepath not in str_to_filename[s["filename"]]:
return stack[len(stack) - i - 3 : len(stack) - i]
return stack[-3:]
@ -103,7 +80,9 @@ class FailureReport:
The specified input dynamic_shapes spec was found to be incorrect during tracing.
Specifically, this guard was added: {self.data["expr"]}, where {self.data["symbol_to_sources"]}.
This occured at the following stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}.
Because of this, we have modified the dynamic shapes structure to be the following:
Because of this, we have modified the dynamic shapes structure to be the
following. You can also use torch.export.Dim.AUTO instead to specify your
dynamic shapes, and we will automatically infer the dynamism for you.
```
dynamic_shapes = {self.data["new_dynamic_shapes"]}
```
@ -219,7 +198,6 @@ def draft_export(
capture_structured_log = CaptureStructuredTrace(
[
"str",
"propagate_real_tensors",
"guard_added",
"missing_fake_kernel",
@ -256,7 +234,9 @@ def draft_export(
preserve_module_call_signature=preserve_module_call_signature,
)
str_to_filename: Dict[str, str] = {}
str_to_filename: Dict[str, str] = {
str(v): k for (k, v) in torch._logging.structured.INTERN_TABLE.items()
}
failures: List[FailureReport] = []
custom_ops_logs: Dict[
Any, Tuple[Dict[str, Any], FailureType]
@ -278,11 +258,6 @@ def draft_export(
data_dependent_logs[hash_stack(log_contents["stack"])] = log_contents
failure_type = FailureType.DATA_DEPENDENT_ERROR
elif log_name == "str":
filename, idx = log_contents
str_to_filename[str(idx)] = filename
continue
elif log_name == "guard_added":
if new_shapes is None:
continue