Fix PT2 Source Code Annotations (#136460)

Summary: In D60803317, we added CompileContext (trace_id) information to Kineto traces using caching when a CompileContext exits. As pointed out by some users, this gives inaccurate IDs because we are not getting the context that is being looked up within the eval_frame. For this reason, we decided to revert that change and go with an approach that involves getting the trace_id associated with a given CacheEntry. To do this, we add a trace_id to the GuardedCode so that it can be passed on to a CacheEntry. Then, we change the lookup function to return said trace_id alongside the code so that we can pass both into our eval function. Once we get to a Torch-Compiled Region, we can just append the context information to the name of the annotation, thus bypassing any need for kwargs.

Test Plan: Added a more comprehensive unit test. Saw that all the trace_ids appeared within the graph.

Differential Revision: D63138786

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136460
Approved by: https://github.com/ezyang
This commit is contained in:
Shivam Raikundalia 2024-09-28 03:54:43 +00:00 committed by PyTorch MergeBot
parent 8df97d78c2
commit 9e4f24f8e5
9 changed files with 76 additions and 25 deletions

View File

@ -163,20 +163,34 @@ class DynamoProfilerTests(torch._dynamo.test_case.TestCase):
)
def test_profiler_dynamo_compiled_region(self):
def fn(x, y, z):
return x @ y + z
opt_fn = torch._dynamo.optimize("eager")(fn)
inputs = [torch.rand(4, 4) for _ in range(3)]
for _ in range(2):
opt_fn(*inputs)
def fn(x, y):
r = y.sum(dim=1)
print(r.shape)
return x * r
with torch.profiler.profile() as prof:
opt_fn(*inputs)
fn_c = torch.compile(fn)
self.assertTrue(any(e.name == "Torch-Compiled Region" for e in prof.events()))
fn_c(
torch.randn(10),
torch.randn(10, 10),
)
fn_c(
torch.randn(10),
torch.randn(10, 15),
)
annotations = [e.name for e in prof.events() if "Compiled" in e.name]
self.assertEqual(
annotations,
[
"Torch-Compiled Region: 0/0",
"Torch-Compiled Region: 1/0",
"Torch-Compiled Region: 0/1",
"Torch-Compiled Region: 1/0",
],
)
if __name__ == "__main__":

View File

@ -804,7 +804,11 @@ def _compile(
hooks.guard_fail_fn if hooks else None,
)
guarded_code = GuardedCode(out_code, check_fn.check_fn, compile_id)
compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
annotation_str = "Torch-Compiled Region: " + compile_id_str
guarded_code = GuardedCode(
out_code, check_fn.check_fn, compile_id, annotation_str
)
if not output.is_empty_graph() and hooks.guard_export_fn is not None:
# We should not run the guard_export_fn when Dynamo does not

View File

@ -48,6 +48,7 @@ class GuardedCode:
code: types.CodeType
check_fn: GuardFn
compile_id: CompileId
trace_annotation: str = "Unknown"
class DynamoCallbackFn(Protocol):

View File

@ -9,6 +9,13 @@ CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
this->check_fn = guarded_code.attr("check_fn");
this->code = guarded_code.attr("code");
this->compile_id = guarded_code.attr("compile_id");
// Copy the annotation text out of the Python object so this CacheEntry owns
// its own std::string.
py::object trace_annotation = guarded_code.attr("trace_annotation");
const char* trace_annotation_str = PyUnicode_AsUTF8(trace_annotation.ptr());
// Check the C string, not the py::object: the object is always non-null once
// the attribute lookup succeeded, whereas PyUnicode_AsUTF8 returns NULL (with
// a Python exception set) when the value cannot be converted.
if (trace_annotation_str != nullptr) {
  this->trace_annotation = trace_annotation_str;
} else {
  // Drop the pending conversion error and fall back to the placeholder name,
  // so no stale exception leaks into later Python C-API calls.
  PyErr_Clear();
  this->trace_annotation = "Unknown";
}
// TODO - clean this up when enable_cpp_guard_manager is True by default
if (py::hasattr(this->check_fn, "root")) {
this->root_mgr = torch::dynamo::convert_to_root_guard_manager(
@ -42,6 +49,10 @@ PyCodeObject* CacheEntry_get_code(CacheEntry* e) {
return (PyCodeObject*)e->code.ptr();
}
// Returns the C-string view of the entry's trace annotation. The pointer
// comes from the std::string member's c_str(); no copy is made here.
const char* CacheEntry_get_trace_annotation(CacheEntry* e) {
  const std::string& annotation = e->trace_annotation;
  return annotation.c_str();
}
PyObject* CacheEntry_to_obj(CacheEntry* e) {
if (!e) {
return py::none().release().ptr();

View File

@ -54,6 +54,8 @@ typedef struct VISIBILITY_HIDDEN CacheEntry {
ExtraState* _owner{nullptr};
// Reference to this CacheEntry's location in owner's linked list
std::list<CacheEntry>::iterator _owner_loc;
// String representation of the CompileContext (stored by value)
std::string trace_annotation;
CacheEntry(const py::handle& guarded_code, PyObject* backend);
~CacheEntry();
@ -69,6 +71,9 @@ C10_DIAGNOSTIC_POP()
// Returns borrowed reference
PyCodeObject* CacheEntry_get_code(CacheEntry* e);
// Returns borrowed string representation of CompileContext
const char* CacheEntry_get_trace_annotation(CacheEntry* e);
// Returns a borrowed reference to CacheEntry as a PyObject
// Warning: lifetime is controlled by C++
PyObject* CacheEntry_to_obj(CacheEntry* e);

View File

@ -481,9 +481,11 @@ inline static PyObject* eval_custom_code(
PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame,
PyCodeObject* code,
const char* trace_annotation,
int throw_flag,
int free_vars_copied) {
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter("Torch-Compiled Region");
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(trace_annotation);
PyObject* result = eval_custom_code_impl(
tstate,
frame,
@ -615,7 +617,9 @@ static PyObject* _custom_eval_frame(
if (callback == Py_False || extra_state_cache_limit_hit(extra)) {
DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
PyObject* maybe_cached_code = lookup(extra, locals, backend);
PyObject* maybe_cached_code = NULL;
const char* trace_annotation = "";
lookup(extra, locals, backend, &maybe_cached_code, &trace_annotation);
_pytorch_record_function_exit(rf);
Py_DECREF(locals);
@ -641,7 +645,7 @@ static PyObject* _custom_eval_frame(
// used cached version
DEBUG_TRACE("cache hit %s", get_frame_name(frame));
*should_clear_frame = 1;
return eval_custom_code(tstate, frame, cached_code, throw_flag, 0);
return eval_custom_code(tstate, frame, cached_code, trace_annotation, throw_flag, 0);
}
DEBUG_CHECK(PyDict_CheckExact(locals));
DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
@ -653,7 +657,9 @@ static PyObject* _custom_eval_frame(
eval_frame_callback_set(Py_None);
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
PyObject* maybe_cached_code = lookup(extra, locals, backend);
PyObject* maybe_cached_code = NULL;
const char* trace_annotation = "";
lookup(extra, locals, backend, &maybe_cached_code, &trace_annotation);
_pytorch_record_function_exit(rf);
if (maybe_cached_code == NULL) {
// Python error
@ -668,7 +674,7 @@ static PyObject* _custom_eval_frame(
eval_frame_callback_set(callback);
*should_clear_frame = 1;
Py_DECREF(locals);
return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied);
return eval_custom_code(tstate, frame, cached_code, trace_annotation, throw_flag, free_vars_copied);
}
// cache miss
CacheEntry* cache_entry = extract_cache_entry(extra);
@ -719,7 +725,8 @@ static PyObject* _custom_eval_frame(
// Re-enable custom behavior
eval_frame_callback_set(callback);
*should_clear_frame = 1;
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag, free_vars_copied);
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry),
CacheEntry_get_trace_annotation(new_cache_entry), throw_flag, free_vars_copied);
} else {
DEBUG_TRACE("create skip %s", get_frame_name(frame));
Py_DECREF(result);

View File

@ -109,10 +109,12 @@ bool backend_match(PyObject* saved_backend, PyObject* backend) {
return true;
}
PyObject* lookup(
void lookup(
ExtraState* extra_state,
PyObject* f_locals,
PyObject* backend) {
PyObject* backend,
PyObject** maybe_cached_code,
const char** trace_annotation) {
size_t index = 0;
CacheEntry* found = nullptr;
py::handle locals(f_locals);
@ -145,7 +147,8 @@ PyObject* lookup(
// this function is called from C, so we cannot repropagate
// the exception
e.restore();
return nullptr;
*maybe_cached_code = nullptr;
return;
}
}
if (valid) {
@ -156,9 +159,11 @@ PyObject* lookup(
}
if (found) {
extra_state->move_to_front(found);
return found->code.ptr();
*maybe_cached_code = found->code.ptr();
*trace_annotation = found->trace_annotation.c_str();
return;
}
return py::none().ptr();
*maybe_cached_code = py::none().ptr();
}
CacheEntry* create_cache_entry(

View File

@ -143,10 +143,13 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code);
// - f_locals: Borrowed
// return:
// - Py_None or PyCodeObject: Borrowed reference.
PyObject* lookup(
// - *trace_annotation: on a cache hit, set to a borrowed C string holding the
//   matching entry's trace annotation; otherwise left unchanged.
void lookup(
ExtraState* extra_state,
PyObject* f_locals,
PyObject* backend);
PyObject* backend,
PyObject** maybe_cached_code,
const char** trace_annotation);
// Create a new cache entry at extra_state holding on to guarded_code.
// Ownership contract

View File

@ -67,6 +67,7 @@ void initDynamoBindings(PyObject* torch) {
.def_readonly("check_fn", &CacheEntry::check_fn)
.def_readonly("code", &CacheEntry::code)
.def_readonly("compile_id", &CacheEntry::compile_id)
.def_readonly("trace_annotation", &CacheEntry::trace_annotation)
.def_property_readonly("next", &CacheEntry::next);
py::class_<ExtraState>(m, "_ExtraState")