[Profiler] Fix unexpected C return events (#159574)

The fix in https://github.com/pytorch/pytorch/pull/155446 addressed the "stack empty" issue that's easily reproducible on CPython 3.12.0-4. While this issue can also appear in other versions, it's not as easy to reproduce there.

I recently found a new cause for this problem.

1df5d00145/Python/ceval.c (L5807-L5836)

In the CPython 3.10 implementation, PyTrace_C_CALL and PyTrace_C_RETURN/PyTrace_C_EXCEPTION are supposed to appear in pairs. However, when c_profilefunc is changed, unexpected PyTrace_C_RETURN/PyTrace_C_EXCEPTION events can occur.

Here is the code to reproduce this problem.

```
import threading
import time
import torch

from threading import Event, Lock

# The main thread takes the lock first so that the worker thread will later
# block inside lock.acquire() — a C call — while profilers come and go.
lock = Lock()
lock.acquire()

event1 = Event()
event2 = Event()
event3 = Event()

def run():
    event1.set()    # signal the main thread that the worker has started
    event2.wait()   # hold here until the first profiling session is active
    lock.acquire()  # C call: enters while profiler #1 is installed, blocks
    event3.set()    # signal the main thread that the acquire has returned

threading.Thread(target=run).start()

# First session: the worker issues PyTrace_C_CALL for lock.acquire() while
# this profiler is installed, and is still blocked inside that C call when
# the session ends (profiler torn down / c_profilefunc swapped).
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU], with_stack=True):
    event1.wait()
    event2.set()
    time.sleep(1)

# Second session: releasing the lock lets the worker's acquire() return, so
# the newly installed profiler receives a PyTrace_C_RETURN without ever
# having seen the matching PyTrace_C_CALL — the unexpected event being fixed.
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU], with_stack=True):
    lock.release()
    event3.wait()
```

<img width="1766" height="1250" alt="image" src="https://github.com/user-attachments/assets/6794eeca-7364-429e-91eb-62cdad116bd3" />

To fix this problem, we record `active_frames_` and `remaining_start_frames_` for each thread; when a PyTrace_C_RETURN/PyTrace_C_EXCEPTION event occurs, these two fields determine whether the event should be recorded.

In reality, even without this fix, the final data appears to be correct, since the matching process can handle this case (it would just result in an exception log being printed).

Do you think the fix is necessary?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159574
Approved by: https://github.com/sraikund16
This commit is contained in:
Denghui Dong 2025-08-07 01:17:52 +00:00 committed by PyTorch MergeBot
parent 5cedc5a0ff
commit 8b0be7b65a
2 changed files with 61 additions and 3 deletions

View File

@ -1,6 +1,7 @@
# Owner(s): ["oncall: profiler"]
import json
import subprocess
import sys
import time
@ -63,6 +64,46 @@ class TestPythonTracer(TestCase):
name = monitoring.get_tool(2)
self.assertEqual(name, None)
    def test_unexpected_c_return_events(self):
        # Regression test: a C call (lock.acquire) that begins while one
        # profiling session is active and finishes inside a later session
        # must not be recorded as an unmatched PyTrace_C_RETURN. Before the
        # fix this scenario printed "Python replay stack is empty during
        # pop operation" to stderr during trace post-processing.
        code = """
import threading
import time
import torch
from threading import Event, Lock
lock = Lock()
lock.acquire()
event1 = Event()
event2 = Event()
event3 = Event()
def run():
    event1.set()
    event2.wait()
    lock.acquire()
    event3.set()
threading.Thread(target=run).start()
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU], with_stack=True):
    event1.wait()
    event2.set()
    time.sleep(1)
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU], with_stack=True):
    lock.release()
    event3.wait()
"""
        # Run the reproducer in a fresh interpreter; check=True makes a
        # crash in the child surface as an exception here.
        result = subprocess.run(
            [sys.executable, "-c", code], capture_output=True, text=True, check=True
        )
        # The error log must not appear once unmatched C-return events are
        # filtered out by the tracer.
        self.assertFalse(
            "Python replay stack is empty during pop operation" in result.stderr
        )
if __name__ == "__main__":
run_tests()

View File

@ -674,6 +674,9 @@ struct ThreadLocalResults {
CallTypeHelper<TraceKeyCacheState>::tuple_type trace_keys_;
AppendOnlyList<c10::approx_time_t, BLOCK_SIZE> exit_times_;
AppendOnlyList<c10::approx_time_t, BLOCK_SIZE> c_exit_times_;
int active_frames_{0};
int remaining_start_frames_{0};
};
// ============================================================================
@ -999,7 +1002,8 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
PyThreadState_Swap(thread_state);
thread_local_results_.emplace_back(thread_state, &value_cache_, this);
auto* ctx = thread_local_results_.back().ctx_;
auto& tls = thread_local_results_.back();
auto* ctx = tls.ctx_;
// When we begin profiling there are already frames on the Python
// interpreter stack. To ensure a complete trace, we must push calls
@ -1021,7 +1025,7 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
}
for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) {
recordPyCall(thread_local_results_.back(), it->get(), true);
recordPyCall(tls, it->get(), true);
auto frame_refcount = Py_REFCNT(it->get());
// We hold one reference in `current_stack`, and the interpreter holds
@ -1029,6 +1033,8 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount);
}
tls.remaining_start_frames_ = tls.active_frames_;
// Note:
// This profile will not compose with other CPython profilers, and
// cannot be round tripped via `sys.settrace(sys.gettrace())`
@ -1141,6 +1147,7 @@ void PythonTracer::recordPyCall(
const auto time = c10::getApproximateTime();
is_startup_frame ? start_frames_.push_back({key, time})
: queue_->getSubqueue()->emplace_py_call(key, time);
++tls.active_frames_;
}
void PythonTracer::recordCCall(
@ -1160,6 +1167,7 @@ void PythonTracer::recordCCall(
auto key = tls.intern<CallType::PyCCall, EventType::PyCCall>(
arg, (void*)(fn->m_ml), frame);
queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime());
++tls.active_frames_;
}
// ============================================================================
@ -1457,11 +1465,20 @@ int PythonTracer::pyProfileFn(
case PyTrace_RETURN:
local_results.exit_times_.emplace_back(c10::getApproximateTime());
local_results.active_frames_--;
if (local_results.active_frames_ <
local_results.remaining_start_frames_) {
local_results.remaining_start_frames_ = local_results.active_frames_;
}
break;
case PyTrace_C_EXCEPTION:
case PyTrace_C_RETURN:
if (local_results.active_frames_ >
local_results.remaining_start_frames_) {
local_results.c_exit_times_.emplace_back(c10::getApproximateTime());
local_results.active_frames_--;
}
break;
}
return 0;