[Dynamo] Update the implementation of _debug_get_cache_entry_list (#108335)

In https://github.com/pytorch/pytorch/pull/106673, I created a private API, `_debug_get_cache_entry_list`, to help pull cache entries out of compiled functions.

Recently, I found that @anijain2305 left a comment in the code saying this API should be revisited, so I created this PR.

First, this API cannot be removed even if the cache entry becomes a first-class Python class, `torch._C._dynamo.eval_frame._CacheEntry`. Because `extra_index` is static and `get_extra_state` is inline static, neither is accessible elsewhere, so `_debug_get_cache_entry_list` is the only way for users to get all the cache entries from a code object.

Second, since `torch._C._dynamo.eval_frame._CacheEntry` is a Python class, I simplified the C code and removed the need to create a namedtuple for it in the Python code.
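
With this change the C function returns only the head of the cache-entry linked list (or `None`), and the Python side walks it. A minimal sketch of that traversal follows. `FakeCacheEntry` and `flatten_cache_entries` are hypothetical names standing in for the real C-defined class and the real helper, and only the `check_fn`, `code`, and `next` fields that appear in the diff are assumed:

```python
from typing import Any, List, Optional


class FakeCacheEntry:
    """Hypothetical stand-in for the C-defined _CacheEntry; not the real class."""

    def __init__(
        self,
        check_fn: Any,
        code: Any,
        next: Optional["FakeCacheEntry"] = None,
    ) -> None:
        self.check_fn = check_fn  # guard for this entry
        self.code = code          # compiled code object for this entry
        self.next = next          # following entry, or None at the tail


def flatten_cache_entries(head: Optional[FakeCacheEntry]) -> List[FakeCacheEntry]:
    """Collect every node of the linked list, starting from its head."""
    entries = []
    while head is not None:
        entries.append(head)
        head = head.next
    return entries
```

An empty cache is represented by a `None` head, so the sketch returns an empty list in that case, which matches the `len(entries) == 0` check in the updated test.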

Third, I also added a small improvement: if the argument is a function, its `__code__` is automatically passed to the API.
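
The argument handling can be sketched in isolation as below. `resolve_code` is a hypothetical helper name used only for illustration; the actual change inlines this check in `_debug_get_cache_entry_list`:

```python
import types
from typing import Any, Callable, Union


def resolve_code(target: Union[types.CodeType, Callable[..., Any]]) -> types.CodeType:
    """Return the code object for a function, or the code object itself."""
    if callable(target):
        # Plain Python functions expose their code object via __code__.
        return target.__code__
    # Code objects are not callable, so they pass through unchanged.
    return target
```

Under this sketch, a callable without a `__code__` attribute (a built-in such as `len`, for example) would raise `AttributeError`.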

These changes slightly alter the output, from a list of named tuples to a list of `torch._C._dynamo.eval_frame._CacheEntry` objects. I will update the corresponding docs that use this API.
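
For callers, code that reads entries by attribute name should be unaffected by the new return type. The sketch below illustrates this with stand-ins only: `OldEntry` mirrors the removed namedtuple, and a `SimpleNamespace` imitates the new class (an assumption for illustration, not the real `_CacheEntry`). The source does not say whether the new class supports tuple-style unpacking, so the sketch avoids it:

```python
from collections import namedtuple
from types import SimpleNamespace

# Shape of an entry before this change (the removed namedtuple).
OldEntry = namedtuple("CacheEntry", "check_fn, code")
old_entry = OldEntry(check_fn="guard", code="compiled_code")

# Hypothetical stand-in for an entry after this change.
new_entry = SimpleNamespace(check_fn="guard", code="compiled_code", next=None)


def describe(entry):
    """Read an entry through attribute access only."""
    return (entry.check_fn, entry.code)
```

Here `describe` accepts either shape because it relies only on the `check_fn` and `code` attribute names.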

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108335
Approved by: https://github.com/jansel, https://github.com/anijain2305
youkaichao 2023-09-02 16:38:59 +00:00 committed by PyTorch MergeBot
parent de58600126
commit b9fc6d7ded
3 changed files with 24 additions and 29 deletions


@@ -92,13 +92,13 @@ class MiscTests(torch._dynamo.test_case.TestCase):
             return x + 1
         torch.compile(f)(torch.randn(5, 5, 5))
-        entries = _debug_get_cache_entry_list(f.__code__)
+        entries = _debug_get_cache_entry_list(f)
         self.assertTrue(len(entries) > 0)
         def g(x):
             return x + 2
-        entries = _debug_get_cache_entry_list(g.__code__)
+        entries = _debug_get_cache_entry_list(g)
         self.assertTrue(len(entries) == 0)
         try:


@@ -12,7 +12,6 @@ import threading
 import traceback
 import types
 import warnings
-from collections import namedtuple
 from enum import Enum
 from os.path import dirname, join
 from typing import (
@@ -95,15 +94,26 @@ DONT_WRAP_FILES = {
 }
-CacheEntry = namedtuple("CacheEntry", "check_fn, code")
+# This class has a `check_fn` field for the guard,
+# and a `code` field for the code object.
+CacheEntry = torch._C._dynamo.eval_frame._CacheEntry
-def _debug_get_cache_entry_list(code: types.CodeType) -> List[CacheEntry]:
+def _debug_get_cache_entry_list(
+    code: Union[types.CodeType, Callable[..., Any]]
+) -> List[CacheEntry]:  # type: ignore[valid-type]
     """
-    Given a code object, retrieve the cache entries stored in this code.
+    Given a code object or a callable object, retrieve the cache entries
+    stored in this code.
     """
-    cache_list = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
-    return list(map(CacheEntry._make, cache_list))
+    if callable(code):
+        code = code.__code__
+    cache_head = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
+    cache_list = []
+    while cache_head is not None:
+        cache_list.append(cache_head)
+        cache_head = cache_head.next
+    return cache_list
 class OptimizedModule(torch.nn.Module):


@@ -583,8 +583,7 @@ Debugger helper functions.
 */
 PyObject* _debug_get_cache_entry_list(PyObject* self, PyObject* args) {
-  // TODO(anijain2305) - CacheEntry being the first class Python object might
-  // obviate the need of this function. Revisit.
+  // get the cache entry out of a code object
   PyObject* object = NULL;
   if (!PyArg_ParseTuple(args, "O", &object)) {
     return NULL;
@@ -597,26 +596,12 @@ PyObject* _debug_get_cache_entry_list(PyObject* self, PyObject* args) {
   ExtraState* extra = get_extra_state(code);
   CacheEntry* current_node = extract_cache_entry(extra);
-  PyObject* outer_list = PyList_New(0);
-  if (!outer_list) {
-    return NULL; // Return NULL if failed to create list
+  if (current_node == NULL)
+  {
+    Py_RETURN_NONE;
   }
-  while (current_node != NULL && current_node != (CacheEntry*)Py_None) {
-    // Creating a new Python tuple for the check_fn and code of current CacheEntry
-    PyObject* inner_list = PyTuple_Pack(2, current_node->check_fn, current_node->code);
-    int flag = PyList_Append(outer_list, inner_list); // Add the inner list to the outer list
-    Py_DECREF(inner_list); // Decrement our own reference
-    if (flag < 0) {
-      Py_DECREF(outer_list); // Clean up if failed to append
-      return NULL;
-    }
-    // Move to the next node in the linked list
-    current_node = current_node->next;
-  }
-  // Return the outer list
-  return outer_list;
+  Py_INCREF(current_node);
+  return (PyObject*)current_node;
 }
 static inline PyObject* call_callback(