mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[retry-land][pytorch][dynamo_compile] Log stack_trace to dynamo_compile (#160348)
Refer to https://github.com/pytorch/pytorch/pull/159655. The earlier PR failed on dynamo/test_utils.py::TestDynamoTimed::test_dynamo_timed. This retry updates test_dynamo_timed and was re-run locally to verify. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160348 Approved by: https://github.com/masnesral
This commit is contained in:
parent
01bcf9a40d
commit
9a0f7a3bb0
|
|
@ -246,6 +246,32 @@ class TestDynamoTimed(TestCase):
|
|||
utils.reset_frame_count()
|
||||
torch._logging._internal.structured_logging_overhead.clear()
|
||||
|
||||
@dynamo_config.patch({"log_compilation_metrics": True})
@inductor_config.patch({"force_disable_caches": True})
def test_stack_trace(self):
    """Check that logged compilation events carry a stack trace naming this test.

    ``log_compilation_event`` is mocked so that the event objects passed to it
    can be inspected directly; the ``stack_trace`` attribute of every event is
    collected and the combined text must mention ``test_stack_trace``.
    """
    self.warmup()

    with mock.patch("torch._dynamo.utils.log_compilation_event") as log_event:
        self.run_forward_backward()
        # The event is the first positional argument of every recorded call.
        compilation_events = [call.args[0] for call in log_event.call_args_list]

    stack_trace_list = [event.stack_trace for event in compilation_events]
    self.assertGreater(len(stack_trace_list), 0)

    # Flatten into individual lines. An event's stack_trace may be unset/empty
    # (skipped), a list of strings, or a single string.
    trace_lines = []
    for trace in stack_trace_list:
        if not trace:
            continue
        if isinstance(trace, list):
            trace_lines.extend(trace)
        else:
            trace_lines.append(trace)

    result = "\n".join(trace_lines)
    self.assertIn(
        "test_stack_trace",
        result,
        "Compilation event stack traces do not contain the expected "
        "string: 'test_stack_trace'",
    )
|
||||
|
||||
@dynamo_config.patch(
|
||||
{
|
||||
"log_compilation_metrics": True,
|
||||
|
|
@ -396,6 +422,7 @@ class TestDynamoTimed(TestCase):
|
|||
e.cuda_version = None
|
||||
e.triton_version = None
|
||||
e.python_version = None
|
||||
e.stack_trace = None
|
||||
|
||||
# First event is for the forward. Formatting makes reading diffs
|
||||
# much easier.
|
||||
|
|
@ -479,6 +506,7 @@ class TestDynamoTimed(TestCase):
|
|||
'runtime_triton_autotune_time_us': None,
|
||||
'shape_env_guard_count': 0,
|
||||
'specialize_float': False,
|
||||
'stack_trace': None,
|
||||
'start_time': 0.0001,
|
||||
'start_time_us': 100,
|
||||
'structured_logging_overhead_s': 0.0,
|
||||
|
|
@ -560,6 +588,7 @@ class TestDynamoTimed(TestCase):
|
|||
'runtime_triton_autotune_time_us': None,
|
||||
'shape_env_guard_count': 0,
|
||||
'specialize_float': False,
|
||||
'stack_trace': None,
|
||||
'start_time': 0.0001,
|
||||
'start_time_us': 100,
|
||||
'structured_logging_overhead_s': 0.0,
|
||||
|
|
@ -652,6 +681,7 @@ class TestDynamoTimed(TestCase):
|
|||
'runtime_triton_autotune_time_us': None,
|
||||
'shape_env_guard_count': None,
|
||||
'specialize_float': None,
|
||||
'stack_trace': None,
|
||||
'start_time': 0.0001,
|
||||
'start_time_us': 100,
|
||||
'structured_logging_overhead_s': 0.0,
|
||||
|
|
@ -733,6 +763,7 @@ class TestDynamoTimed(TestCase):
|
|||
'runtime_triton_autotune_time_us': None,
|
||||
'shape_env_guard_count': None,
|
||||
'specialize_float': None,
|
||||
'stack_trace': None,
|
||||
'start_time': 0.0001,
|
||||
'start_time_us': 100,
|
||||
'structured_logging_overhead_s': 0.0,
|
||||
|
|
|
|||
|
|
@ -225,30 +225,35 @@ def fx_forward_from_src_skip_result(
|
|||
return result
|
||||
|
||||
|
||||
def log_dynamo_start(code: CodeType, skip: int = 0) -> None:
|
||||
def log_dynamo_start(code: CodeType, skip: int = 0) -> list[str]:
|
||||
convert_frame_intern = structured.intern_string(__file__)
|
||||
# Extract and filter the stack
|
||||
stack = list(
|
||||
itertools.takewhile(
|
||||
lambda f: f["filename"] != convert_frame_intern,
|
||||
structured.from_traceback(
|
||||
CapturedTraceback.extract(skip=4 + skip).summary()
|
||||
),
|
||||
)
|
||||
) + [
|
||||
{
|
||||
"line": code.co_firstlineno,
|
||||
"name": code.co_name,
|
||||
"filename": structured.intern_string(code.co_filename),
|
||||
}
|
||||
]
|
||||
# Initialize the ChromiumEventLogger on start
|
||||
torch._logging.trace_structured(
|
||||
"dynamo_start",
|
||||
lambda: {
|
||||
"stack": list(
|
||||
itertools.takewhile(
|
||||
lambda f: f["filename"] != convert_frame_intern,
|
||||
structured.from_traceback(
|
||||
CapturedTraceback.extract(skip=4 + skip).summary()
|
||||
),
|
||||
)
|
||||
)
|
||||
+ [
|
||||
{
|
||||
"line": code.co_firstlineno,
|
||||
"name": code.co_name,
|
||||
"filename": structured.intern_string(code.co_filename),
|
||||
}
|
||||
]
|
||||
},
|
||||
lambda: {"stack": stack},
|
||||
)
|
||||
|
||||
stack_strings = [
|
||||
f"Line: {frame['line']}, Name: {frame['name']}, Filename: {frame['filename']}"
|
||||
for frame in stack
|
||||
]
|
||||
return stack_strings
|
||||
|
||||
|
||||
def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
"""
|
||||
|
|
@ -1160,7 +1165,7 @@ def _compile(
|
|||
# # 2 extra here
|
||||
# torch/_logging/_internal.py:1064 in trace_structured
|
||||
# torch/_dynamo/convert_frame.py:780 in <lambda>
|
||||
log_dynamo_start(code, skip)
|
||||
stack_trace = log_dynamo_start(code, skip)
|
||||
start_time_ns = time.time_ns()
|
||||
fail_type: Optional[str] = None
|
||||
fail_reason: Optional[str] = None
|
||||
|
|
@ -1300,6 +1305,7 @@ def _compile(
|
|||
"dynamo_compile_time_before_restart_us": to_int_us(
|
||||
dynamo_time_before_restart
|
||||
),
|
||||
"stack_trace": stack_trace,
|
||||
}
|
||||
# TODO: replace with CompileEventLogger.compilation_metrics
|
||||
# There are some columns here not in PT2 Compile Events
|
||||
|
|
|
|||
|
|
@ -1288,6 +1288,7 @@ class CompilationMetrics:
|
|||
compliant_custom_ops: Optional[set[str]] = None
|
||||
restart_reasons: Optional[set[str]] = None
|
||||
dynamo_time_before_restart_s: Optional[float] = None
|
||||
stack_trace: Optional[list[str]] = None
|
||||
# Sometimes, we will finish analyzing a frame but conclude we don't want
|
||||
# to install any guarded code. True means we actually decided to install
|
||||
# a compiled frame
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user