[export] Dedup data-dependent errors based on stacktrace (#139540)

Summary:
Dedup the data-dependent errors based on the stacktrace it points to. Right now we just display every propagate-real-tensor log that shows up, but we actually can dedup them if they are due to the same piece of code (ex. there could multiple calls to a piece of code that does some data dependent computation).

This occurred when trying out draft export on the PT2I model zoo. For a specific model, previously we would get ~3k data dependent errors, but after deduping based on the stacktrace we now only get 4 errors.

Test Plan: CI

Differential Revision: D65374254

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139540
Approved by: https://github.com/pianpwk, https://github.com/zou3519
This commit is contained in:
Angela Yi 2024-11-05 18:16:03 +00:00 committed by PyTorch MergeBot
parent cc25b6d7ba
commit de509abe1c
2 changed files with 49 additions and 11 deletions

View File

@ -142,6 +142,29 @@ class TestDraftExport(TestCase):
inp = (torch.randn(3, 3), torch.randn(3, 3), torch.tensor(2))
self.assertEqual(ep.module()(*inp), M()(*inp))
def test_dedup_data_dependent_failure(self):
class M(torch.nn.Module):
def forward(self, x, y, z):
res = 0
for v in [x, y]:
if v.item() > 10:
res += v * v
else:
res += v + v
return z * res
inp = (torch.tensor(5), torch.tensor(3), torch.tensor(2))
ep, report = draft_export(M(), inp)
self.assertTrue(len(report.failures) > 0)
self.assertEqual(
report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
)
inp = (torch.tensor(4), torch.tensor(2), torch.tensor(6))
self.assertEqual(ep.module()(*inp), M()(*inp))
def test_offsets(self):
class M(torch.nn.Module):
def forward(self, x):

View File

@ -41,6 +41,7 @@ def uninteresting_files() -> Set[str]:
torch._logging._internal,
torch._subclasses.meta_utils,
torch._subclasses.fake_tensor,
torch._subclasses.functional_tensor,
]
return {inspect.getfile(m) for m in mods}
@ -62,6 +63,7 @@ def filter_stack(
stack: List[Dict[str, str]], str_to_filename: Dict[str, str]
) -> List[Dict[str, str]]:
for i, s in enumerate(reversed(stack)):
s["filename"] = str(s["filename"])
if s["filename"] not in str_to_filename:
continue
if str_to_filename[s["filename"]] not in uninteresting_files():
@ -69,6 +71,10 @@ def filter_stack(
return stack[-3:]
def hash_stack(stack: List[Dict[str, str]]) -> str:
return ";".join(f'line: {s["line"]} filename: {s["filename"]}' for s in stack)
class FailureReport:
def __init__(
self, failure_type: FailureType, data: Dict[str, Any], xfail: bool = False
@ -234,21 +240,32 @@ def draft_export(
str_to_filename: Dict[str, str] = {}
failures: List[FailureReport] = []
custom_ops_logs: Dict[str, Dict[str, Any]] = {} # Dedup custom ops
data_dependent_logs: Dict[
str, Dict[str, Any]
] = {} # Dedup data dependent errors based on stacktrace
for log_name, log_contents in capture_structured_log.logs:
if log_name == "propagate_real_tensors":
failure_type = FailureType.DATA_DEPENDENT_ERROR
failure_type = None
if log_name == "propagate_real_tensors":
log_contents["stack"] = filter_stack(
log_contents["stack"], str_to_filename
)
if hash_stack(log_contents["stack"]) in data_dependent_logs:
continue
data_dependent_logs[hash_stack(log_contents["stack"])] = log_contents
failure_type = FailureType.DATA_DEPENDENT_ERROR
elif log_name == "str":
filename, idx = log_contents
str_to_filename[str(idx)] = filename
continue
elif log_name == "guard_added":
if new_shapes is None:
continue
failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR
if len(log_contents["symbol_to_sources"]) == 0:
# We only want to include guards added that are relevant to
@ -260,12 +277,18 @@ def draft_export(
log_contents["stack"], str_to_filename
)
log_contents["new_dynamic_shapes"] = new_shapes
elif log_name == "generated_fake_kernel":
if log_contents["op"] in custom_ops_logs:
continue
failure_type = FailureType.MISSING_FAKE_KERNEL
custom_ops_logs[log_contents["op"]] = log_contents
continue
else:
raise RuntimeError(f"Unknown log name: {log_name}")
assert failure_type is not None
failures.append(
FailureReport(
failure_type,
@ -273,14 +296,6 @@ def draft_export(
)
)
for custom_op_log in custom_ops_logs.values():
failures.append(
FailureReport(
FailureType.MISSING_FAKE_KERNEL,
custom_op_log,
)
)
report = DraftExportReport(failures, str_to_filename)
ep._report = report