Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[export] Dedup data-dependent errors based on stacktrace (#139540)
Summary: Dedup the data-dependent errors based on the stacktrace they point to. Previously we displayed every propagate-real-tensor log that showed up, but we can dedup logs that come from the same piece of code (for example, there could be multiple calls to a piece of code that does some data-dependent computation). This came up when trying draft export on the PT2I model zoo: for one model we previously got ~3k data-dependent errors, but after deduping based on the stacktrace we get only 4.

Test Plan: CI

Differential Revision: D65374254

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139540
Approved by: https://github.com/pianpwk, https://github.com/zou3519
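The dedup key is simply a string built from the filtered user stacktrace, so repeated hits of the same source location collapse into one report. A minimal standalone sketch of that pattern (the hash_stack key format matches the diff below; the log records here are made-up stand-ins):

from typing import Any, Dict, List

def hash_stack(stack: List[Dict[str, str]]) -> str:
    # One entry per frame; two errors raised from the same source
    # location produce the same key and therefore collide.
    return ";".join(f'line: {s["line"]} filename: {s["filename"]}' for s in stack)

# Hypothetical captured logs: two errors from the same call site, one from another.
logs = [
    {"stack": [{"line": "12", "filename": "model.py"}], "reason": "u0 > 10"},
    {"stack": [{"line": "12", "filename": "model.py"}], "reason": "u1 > 10"},
    {"stack": [{"line": "30", "filename": "model.py"}], "reason": "u2 != 0"},
]

deduped: Dict[str, Dict[str, Any]] = {}
for log in logs:
    deduped.setdefault(hash_stack(log["stack"]), log)

print(len(deduped))  # 2: the two errors from model.py:12 are reported once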
This commit is contained in:
parent cc25b6d7ba
commit de509abe1c
@@ -142,6 +142,29 @@ class TestDraftExport(TestCase):
         inp = (torch.randn(3, 3), torch.randn(3, 3), torch.tensor(2))
         self.assertEqual(ep.module()(*inp), M()(*inp))
 
+    def test_dedup_data_dependent_failure(self):
+        class M(torch.nn.Module):
+            def forward(self, x, y, z):
+                res = 0
+                for v in [x, y]:
+                    if v.item() > 10:
+                        res += v * v
+                    else:
+                        res += v + v
+
+                return z * res
+
+        inp = (torch.tensor(5), torch.tensor(3), torch.tensor(2))
+
+        ep, report = draft_export(M(), inp)
+        self.assertTrue(len(report.failures) > 0)
+        self.assertEqual(
+            report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
+        )
+
+        inp = (torch.tensor(4), torch.tensor(2), torch.tensor(6))
+        self.assertEqual(ep.module()(*inp), M()(*inp))
+
     def test_offsets(self):
         class M(torch.nn.Module):
             def forward(self, x):
@@ -41,6 +41,7 @@ def uninteresting_files() -> Set[str]:
         torch._logging._internal,
         torch._subclasses.meta_utils,
+        torch._subclasses.fake_tensor,
         torch._subclasses.functional_tensor,
     ]
     return {inspect.getfile(m) for m in mods}
 
@@ -62,6 +63,7 @@ def filter_stack(
     stack: List[Dict[str, str]], str_to_filename: Dict[str, str]
 ) -> List[Dict[str, str]]:
     for i, s in enumerate(reversed(stack)):
+        s["filename"] = str(s["filename"])
         if s["filename"] not in str_to_filename:
             continue
         if str_to_filename[s["filename"]] not in uninteresting_files():
@@ -69,6 +71,10 @@ def filter_stack(
             return stack[-3:]
 
 
+def hash_stack(stack: List[Dict[str, str]]) -> str:
+    return ";".join(f'line: {s["line"]} filename: {s["filename"]}' for s in stack)
+
+
 class FailureReport:
     def __init__(
         self, failure_type: FailureType, data: Dict[str, Any], xfail: bool = False
@@ -234,21 +240,32 @@ def draft_export(
         str_to_filename: Dict[str, str] = {}
         failures: List[FailureReport] = []
         custom_ops_logs: Dict[str, Dict[str, Any]] = {}  # Dedup custom ops
+        data_dependent_logs: Dict[
+            str, Dict[str, Any]
+        ] = {}  # Dedup data dependent errors based on stacktrace
 
         for log_name, log_contents in capture_structured_log.logs:
-            if log_name == "propagate_real_tensors":
-                failure_type = FailureType.DATA_DEPENDENT_ERROR
+            failure_type = None
+
+            if log_name == "propagate_real_tensors":
                 log_contents["stack"] = filter_stack(
                     log_contents["stack"], str_to_filename
                 )
+                if hash_stack(log_contents["stack"]) in data_dependent_logs:
+                    continue
+
+                data_dependent_logs[hash_stack(log_contents["stack"])] = log_contents
+                failure_type = FailureType.DATA_DEPENDENT_ERROR
 
             elif log_name == "str":
                 filename, idx = log_contents
                 str_to_filename[str(idx)] = filename
                 continue
 
             elif log_name == "guard_added":
                 if new_shapes is None:
                     continue
 
                 failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR
                 if len(log_contents["symbol_to_sources"]) == 0:
                     # We only want to include guards added that are relevant to
@@ -260,12 +277,18 @@ def draft_export(
                     log_contents["stack"], str_to_filename
                 )
                 log_contents["new_dynamic_shapes"] = new_shapes
 
             elif log_name == "generated_fake_kernel":
+                if log_contents["op"] in custom_ops_logs:
+                    continue
+
                 failure_type = FailureType.MISSING_FAKE_KERNEL
                 custom_ops_logs[log_contents["op"]] = log_contents
-                continue
 
+            else:
+                raise RuntimeError(f"Unknown log name: {log_name}")
+
+            assert failure_type is not None
             failures.append(
                 FailureReport(
                     failure_type,
@@ -273,14 +296,6 @@ def draft_export(
                 )
             )
 
-        for custom_op_log in custom_ops_logs.values():
-            failures.append(
-                FailureReport(
-                    FailureType.MISSING_FAKE_KERNEL,
-                    custom_op_log,
-                )
-            )
-
         report = DraftExportReport(failures, str_to_filename)
 
         ep._report = report
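For context, a rough usage sketch of the API the new test exercises, mirroring test_dedup_data_dependent_failure above. The import path is an assumption; the diff shows draft_export and FailureType in use but not their import lines.

import torch
# Assumed import location -- not shown in the diff.
from torch.export._draft_export import draft_export, FailureType

class M(torch.nn.Module):
    def forward(self, x, y, z):
        res = 0
        for v in [x, y]:  # both iterations hit the same data-dependent line
            if v.item() > 10:
                res += v * v
            else:
                res += v + v
        return z * res

ep, report = draft_export(M(), (torch.tensor(5), torch.tensor(3), torch.tensor(2)))
# The two v.item() checks share a stacktrace, so after this change they
# dedup to a single DATA_DEPENDENT_ERROR failure in the report.
print([f.failure_type == FailureType.DATA_DEPENDENT_ERROR for f in report.failures])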