mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix PT2 Source Code Annotations (#136460)
Summary: In D60803317, we added CompileContext (trace_id) information to Kineto traces using caching when a CompileContext exits. As pointed out by some users, this gives inaccurate IDs because we are not getting the context that is being looked up within the eval_frame. For this reason, we decided to revert that change and go with an approach that involves getting the trace_id associated with a given CacheEntry. To do this, we add a trace_id to the GuardedCode so that it can be passed on to a CacheEntry. Then, we change the lookup function to return said trace_id alongside the code so that we can pass both into our eval function. Once we get to a Torch-Compiled Region, we can just append the context information to the name of the annotation, thus bypassing any need for kwargs. Test Plan: Added a more comprehensive unit test. Saw that all the trace_ids appeared within the graph. Differential Revision: D63138786 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136460 Approved by: https://github.com/ezyang
This commit is contained in:
parent
8df97d78c2
commit
9e4f24f8e5
|
|
@ -163,20 +163,34 @@ class DynamoProfilerTests(torch._dynamo.test_case.TestCase):
|
|||
)
|
||||
|
||||
def test_profiler_dynamo_compiled_region(self):
|
||||
def fn(x, y, z):
|
||||
return x @ y + z
|
||||
|
||||
opt_fn = torch._dynamo.optimize("eager")(fn)
|
||||
|
||||
inputs = [torch.rand(4, 4) for _ in range(3)]
|
||||
|
||||
for _ in range(2):
|
||||
opt_fn(*inputs)
|
||||
def fn(x, y):
|
||||
r = y.sum(dim=1)
|
||||
print(r.shape)
|
||||
return x * r
|
||||
|
||||
with torch.profiler.profile() as prof:
|
||||
opt_fn(*inputs)
|
||||
fn_c = torch.compile(fn)
|
||||
|
||||
self.assertTrue(any(e.name == "Torch-Compiled Region" for e in prof.events()))
|
||||
fn_c(
|
||||
torch.randn(10),
|
||||
torch.randn(10, 10),
|
||||
)
|
||||
|
||||
fn_c(
|
||||
torch.randn(10),
|
||||
torch.randn(10, 15),
|
||||
)
|
||||
|
||||
annotations = [e.name for e in prof.events() if "Compiled" in e.name]
|
||||
self.assertEqual(
|
||||
annotations,
|
||||
[
|
||||
"Torch-Compiled Region: 0/0",
|
||||
"Torch-Compiled Region: 1/0",
|
||||
"Torch-Compiled Region: 0/1",
|
||||
"Torch-Compiled Region: 1/0",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -804,7 +804,11 @@ def _compile(
|
|||
hooks.guard_fail_fn if hooks else None,
|
||||
)
|
||||
|
||||
guarded_code = GuardedCode(out_code, check_fn.check_fn, compile_id)
|
||||
compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
|
||||
annotation_str = "Torch-Compiled Region: " + compile_id_str
|
||||
guarded_code = GuardedCode(
|
||||
out_code, check_fn.check_fn, compile_id, annotation_str
|
||||
)
|
||||
|
||||
if not output.is_empty_graph() and hooks.guard_export_fn is not None:
|
||||
# We should not run the guard_export_fn when Dynamo does not
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ class GuardedCode:
|
|||
code: types.CodeType
|
||||
check_fn: GuardFn
|
||||
compile_id: CompileId
|
||||
trace_annotation: str = "Unknown"
|
||||
|
||||
|
||||
class DynamoCallbackFn(Protocol):
|
||||
|
|
|
|||
|
|
@ -9,6 +9,13 @@ CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
|
|||
this->check_fn = guarded_code.attr("check_fn");
|
||||
this->code = guarded_code.attr("code");
|
||||
this->compile_id = guarded_code.attr("compile_id");
|
||||
py::object trace_annotation = guarded_code.attr("trace_annotation");
|
||||
const char* trace_annotation_str = PyUnicode_AsUTF8(trace_annotation.ptr());
|
||||
if (trace_annotation) {
|
||||
this->trace_annotation = std::string(trace_annotation_str);
|
||||
} else {
|
||||
this->trace_annotation = "Unknown";
|
||||
}
|
||||
// TODO - clean this up when enable_cpp_guard_manager is True by default
|
||||
if (py::hasattr(this->check_fn, "root")) {
|
||||
this->root_mgr = torch::dynamo::convert_to_root_guard_manager(
|
||||
|
|
@ -42,6 +49,10 @@ PyCodeObject* CacheEntry_get_code(CacheEntry* e) {
|
|||
return (PyCodeObject*)e->code.ptr();
|
||||
}
|
||||
|
||||
// Returns the trace annotation stored on cache entry `e` as a C string.
// The pointer comes from e->trace_annotation.c_str(), so it is borrowed from
// the entry's string member rather than copied. `e` is dereferenced
// unconditionally, so callers must pass a non-null entry.
const char* CacheEntry_get_trace_annotation(CacheEntry* e) {
|
||||
return e->trace_annotation.c_str();
|
||||
}
|
||||
|
||||
PyObject* CacheEntry_to_obj(CacheEntry* e) {
|
||||
if (!e) {
|
||||
return py::none().release().ptr();
|
||||
|
|
|
|||
|
|
@ -54,6 +54,8 @@ typedef struct VISIBILITY_HIDDEN CacheEntry {
|
|||
ExtraState* _owner{nullptr};
|
||||
// Reference to this CacheEntry's location in owner's linked list
|
||||
std::list<CacheEntry>::iterator _owner_loc;
|
||||
// String representation of the CompileContext, stored by value
|
||||
std::string trace_annotation;
|
||||
|
||||
CacheEntry(const py::handle& guarded_code, PyObject* backend);
|
||||
~CacheEntry();
|
||||
|
|
@ -69,6 +71,9 @@ C10_DIAGNOSTIC_POP()
|
|||
// Returns borrowed reference
|
||||
PyCodeObject* CacheEntry_get_code(CacheEntry* e);
|
||||
|
||||
// Returns borrowed string representation of CompileContext
|
||||
const char* CacheEntry_get_trace_annotation(CacheEntry* e);
|
||||
|
||||
// Returns a borrowed reference to CacheEntry as a PyObject
|
||||
// Warning: lifetime is controlled by C++
|
||||
PyObject* CacheEntry_to_obj(CacheEntry* e);
|
||||
|
|
|
|||
|
|
@ -481,9 +481,11 @@ inline static PyObject* eval_custom_code(
|
|||
PyThreadState* tstate,
|
||||
THP_EVAL_API_FRAME_OBJECT* frame,
|
||||
PyCodeObject* code,
|
||||
const char* trace_annotation,
|
||||
int throw_flag,
|
||||
int free_vars_copied) {
|
||||
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter("Torch-Compiled Region");
|
||||
|
||||
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(trace_annotation);
|
||||
PyObject* result = eval_custom_code_impl(
|
||||
tstate,
|
||||
frame,
|
||||
|
|
@ -615,7 +617,9 @@ static PyObject* _custom_eval_frame(
|
|||
if (callback == Py_False || extra_state_cache_limit_hit(extra)) {
|
||||
DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
|
||||
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
|
||||
PyObject* maybe_cached_code = lookup(extra, locals, backend);
|
||||
PyObject* maybe_cached_code = NULL;
|
||||
const char* trace_annotation = "";
|
||||
lookup(extra, locals, backend, &maybe_cached_code, &trace_annotation);
|
||||
_pytorch_record_function_exit(rf);
|
||||
|
||||
Py_DECREF(locals);
|
||||
|
|
@ -641,7 +645,7 @@ static PyObject* _custom_eval_frame(
|
|||
// used cached version
|
||||
DEBUG_TRACE("cache hit %s", get_frame_name(frame));
|
||||
*should_clear_frame = 1;
|
||||
return eval_custom_code(tstate, frame, cached_code, throw_flag, 0);
|
||||
return eval_custom_code(tstate, frame, cached_code, trace_annotation, throw_flag, 0);
|
||||
}
|
||||
DEBUG_CHECK(PyDict_CheckExact(locals));
|
||||
DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
|
||||
|
|
@ -653,7 +657,9 @@ static PyObject* _custom_eval_frame(
|
|||
eval_frame_callback_set(Py_None);
|
||||
|
||||
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
|
||||
PyObject* maybe_cached_code = lookup(extra, locals, backend);
|
||||
PyObject* maybe_cached_code = NULL;
|
||||
const char* trace_annotation = "";
|
||||
lookup(extra, locals, backend, &maybe_cached_code, &trace_annotation);
|
||||
_pytorch_record_function_exit(rf);
|
||||
if (maybe_cached_code == NULL) {
|
||||
// Python error
|
||||
|
|
@ -668,7 +674,7 @@ static PyObject* _custom_eval_frame(
|
|||
eval_frame_callback_set(callback);
|
||||
*should_clear_frame = 1;
|
||||
Py_DECREF(locals);
|
||||
return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied);
|
||||
return eval_custom_code(tstate, frame, cached_code, trace_annotation, throw_flag, free_vars_copied);
|
||||
}
|
||||
// cache miss
|
||||
CacheEntry* cache_entry = extract_cache_entry(extra);
|
||||
|
|
@ -719,7 +725,8 @@ static PyObject* _custom_eval_frame(
|
|||
// Re-enable custom behavior
|
||||
eval_frame_callback_set(callback);
|
||||
*should_clear_frame = 1;
|
||||
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag, free_vars_copied);
|
||||
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry),
|
||||
CacheEntry_get_trace_annotation(new_cache_entry), throw_flag, free_vars_copied);
|
||||
} else {
|
||||
DEBUG_TRACE("create skip %s", get_frame_name(frame));
|
||||
Py_DECREF(result);
|
||||
|
|
|
|||
|
|
@ -109,10 +109,12 @@ bool backend_match(PyObject* saved_backend, PyObject* backend) {
|
|||
return true;
|
||||
}
|
||||
|
||||
PyObject* lookup(
|
||||
void lookup(
|
||||
ExtraState* extra_state,
|
||||
PyObject* f_locals,
|
||||
PyObject* backend) {
|
||||
PyObject* backend,
|
||||
PyObject** maybe_cached_code,
|
||||
const char** trace_annotation) {
|
||||
size_t index = 0;
|
||||
CacheEntry* found = nullptr;
|
||||
py::handle locals(f_locals);
|
||||
|
|
@ -145,7 +147,8 @@ PyObject* lookup(
|
|||
// this function is called from C, so we cannot repropagate
|
||||
// the exception
|
||||
e.restore();
|
||||
return nullptr;
|
||||
*maybe_cached_code = nullptr;
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (valid) {
|
||||
|
|
@ -156,9 +159,11 @@ PyObject* lookup(
|
|||
}
|
||||
if (found) {
|
||||
extra_state->move_to_front(found);
|
||||
return found->code.ptr();
|
||||
*maybe_cached_code = found->code.ptr();
|
||||
*trace_annotation = found->trace_annotation.c_str();
|
||||
return;
|
||||
}
|
||||
return py::none().ptr();
|
||||
*maybe_cached_code = py::none().ptr();
|
||||
}
|
||||
|
||||
CacheEntry* create_cache_entry(
|
||||
|
|
|
|||
|
|
@ -143,10 +143,13 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code);
|
|||
// - f_locals: Borrowed
|
||||
// return:
|
||||
// - Py_None or PyCodeObject: Borrowed reference.
|
||||
PyObject* lookup(
|
||||
// - trace_annotation: on a cache hit, receives the hit entry's trace annotation as a borrowed C string; not written on a miss or error.
|
||||
void lookup(
|
||||
ExtraState* extra_state,
|
||||
PyObject* f_locals,
|
||||
PyObject* backend);
|
||||
PyObject* backend,
|
||||
PyObject** maybe_cached_code,
|
||||
const char** trace_annotation);
|
||||
|
||||
// Create a new cache entry at extra_state holding on to guarded_code.
|
||||
// Ownership contract
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ void initDynamoBindings(PyObject* torch) {
|
|||
.def_readonly("check_fn", &CacheEntry::check_fn)
|
||||
.def_readonly("code", &CacheEntry::code)
|
||||
.def_readonly("compile_id", &CacheEntry::compile_id)
|
||||
.def_readonly("trace_annotation", &CacheEntry::trace_annotation)
|
||||
.def_property_readonly("next", &CacheEntry::next);
|
||||
|
||||
py::class_<ExtraState>(m, "_ExtraState")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user