mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Dynamo] Update the implementation of _debug_get_cache_entry_list (#108335)
In https://github.com/pytorch/pytorch/pull/106673 , I created a private API `_debug_get_cache_entry_list` to help pull out cache entries from compiled functions. Recently, I find that @anijain2305 commented in the code that this API should be revisited, and so I created this PR. First, this API cannot be removed even if cache entry becomes a first-class python class`torch._C._dynamo.eval_frame._CacheEntry`. The facts that `extra_index` is static, and `get_extra_state` is inline static, make them not accessible elsewhere. This API `_debug_get_cache_entry_list` is the only way for users to get all the cache entries from code. Second, since the`torch._C._dynamo.eval_frame._CacheEntry` class is a python class, I simplified the C-part code, and remove the necessity of creating a namedtuple for this in the python code. Third, I also add a small improvement, that if the argument is a function, we can automatically pass its `__code__` to the API. The above change will slightly change the output, from list of named tuple to list of `torch._C._dynamo.eval_frame._CacheEntry`. I will update the corresponding docs that use this API. Pull Request resolved: https://github.com/pytorch/pytorch/pull/108335 Approved by: https://github.com/jansel, https://github.com/anijain2305
This commit is contained in:
parent
de58600126
commit
b9fc6d7ded
|
|
@ -92,13 +92,13 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
return x + 1
|
||||
|
||||
torch.compile(f)(torch.randn(5, 5, 5))
|
||||
entries = _debug_get_cache_entry_list(f.__code__)
|
||||
entries = _debug_get_cache_entry_list(f)
|
||||
self.assertTrue(len(entries) > 0)
|
||||
|
||||
def g(x):
|
||||
return x + 2
|
||||
|
||||
entries = _debug_get_cache_entry_list(g.__code__)
|
||||
entries = _debug_get_cache_entry_list(g)
|
||||
self.assertTrue(len(entries) == 0)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ import threading
|
|||
import traceback
|
||||
import types
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from enum import Enum
|
||||
from os.path import dirname, join
|
||||
from typing import (
|
||||
|
|
@ -95,15 +94,26 @@ DONT_WRAP_FILES = {
|
|||
}
|
||||
|
||||
|
||||
CacheEntry = namedtuple("CacheEntry", "check_fn, code")
|
||||
# This class has a `check_fn` field for the guard,
|
||||
# and a `code` field for the code object.
|
||||
CacheEntry = torch._C._dynamo.eval_frame._CacheEntry
|
||||
|
||||
|
||||
def _debug_get_cache_entry_list(code: types.CodeType) -> List[CacheEntry]:
|
||||
def _debug_get_cache_entry_list(
|
||||
code: Union[types.CodeType, Callable[..., Any]]
|
||||
) -> List[CacheEntry]: # type: ignore[valid-type]
|
||||
"""
|
||||
Given a code object, retrieve the cache entries stored in this code.
|
||||
Given a code object or a callable object, retrieve the cache entries
|
||||
stored in this code.
|
||||
"""
|
||||
cache_list = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
|
||||
return list(map(CacheEntry._make, cache_list))
|
||||
if callable(code):
|
||||
code = code.__code__
|
||||
cache_head = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
|
||||
cache_list = []
|
||||
while cache_head is not None:
|
||||
cache_list.append(cache_head)
|
||||
cache_head = cache_head.next
|
||||
return cache_list
|
||||
|
||||
|
||||
class OptimizedModule(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -583,8 +583,7 @@ Debugger helper functions.
|
|||
*/
|
||||
|
||||
PyObject* _debug_get_cache_entry_list(PyObject* self, PyObject* args) {
|
||||
// TODO(anijain2305) - CacheEntry being the first class Python object might
|
||||
// obviate the need of this function. Revisit.
|
||||
// get the cache entry out of a code object
|
||||
PyObject* object = NULL;
|
||||
if (!PyArg_ParseTuple(args, "O", &object)) {
|
||||
return NULL;
|
||||
|
|
@ -597,26 +596,12 @@ PyObject* _debug_get_cache_entry_list(PyObject* self, PyObject* args) {
|
|||
|
||||
ExtraState* extra = get_extra_state(code);
|
||||
CacheEntry* current_node = extract_cache_entry(extra);
|
||||
|
||||
PyObject* outer_list = PyList_New(0);
|
||||
if (!outer_list) {
|
||||
return NULL; // Return NULL if failed to create list
|
||||
if (current_node == NULL)
|
||||
{
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
while (current_node != NULL && current_node != (CacheEntry*)Py_None) {
|
||||
// Creating a new Python tuple for the check_fn and code of current CacheEntry
|
||||
PyObject* inner_list = PyTuple_Pack(2, current_node->check_fn, current_node->code);
|
||||
int flag = PyList_Append(outer_list, inner_list); // Add the inner list to the outer list
|
||||
Py_DECREF(inner_list); // Decrement our own reference
|
||||
if (flag < 0) {
|
||||
Py_DECREF(outer_list); // Clean up if failed to append
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Move to the next node in the linked list
|
||||
current_node = current_node->next;
|
||||
}
|
||||
// Return the outer list
|
||||
return outer_list;
|
||||
Py_INCREF(current_node);
|
||||
return (PyObject*)current_node;
|
||||
}
|
||||
|
||||
static inline PyObject* call_callback(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user