[dynamo] delete dynamo cache entry when guard function is invalidated [attempt 2] (#119107)

Attempt #2 for https://github.com/pytorch/pytorch/pull/117875 to fix https://github.com/pytorch/pytorch/issues/112090.

Summary of changes:
- ~Changed CacheEntry linked list into a doubly-linked list structure to support deletion.~ (done by C++ refactor)
- Added CacheEntry and ExtraState borrowed references to GuardFn so that GuardFn can tell ExtraState to delete CacheEntry when the GuardFn is invalidated.
- ~Added ExtraState raw reference to CacheEntry so that we can get ExtraState to correctly point to the first CacheEntry if it gets deleted.~ (done by C++ refactor)
- CacheEntry destructor needs to reset GuardFn refs to ExtraState/CacheEntry in order to prevent use-after-free.
- code_context values that are nn.GraphModules need to be weakrefs in order to prevent circular references.
- Added tests that check for memory leaks and cache deletion operations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119107
Approved by: https://github.com/jansel
This commit is contained in:
William Wen 2024-02-06 16:07:34 -08:00 committed by PyTorch MergeBot
parent fcc36de9d6
commit ee1c2449f7
13 changed files with 169 additions and 29 deletions

View File

@ -9412,6 +9412,44 @@ fn
c2 = _debug_get_cache_entry_list(fn.__code__)
self.assertEqual(len(c2), 0)
@unittest.skipIf(not TEST_CUDA, "requires cuda")
def test_module_free(self):
"""Test that CUDA memory is freed when a model goes out of scope"""
class Mod(torch.nn.Module):
def __init__(self):
super(Mod, self).__init__()
self.fc = torch.nn.Linear(10000, 10000)
def forward(self, out):
return self.fc(out)
def run(compile):
mod = Mod().cuda()
if compile:
mod = torch.compile(mod, backend="eager")
inp = torch.rand(10000, 10000).cuda()
mod(inp)
def clean_and_report_memory():
import gc
gc.collect()
return torch.cuda.memory_allocated()
run(False)
# mem1 = clean_and_report_memory()
run(True)
mem2 = clean_and_report_memory()
torch._dynamo.reset_code_caches()
mem3 = clean_and_report_memory()
# it's possible for dynamo to hold on to more memory
# even after a _dynamo.reset[_code_caches], so we omit the following check.
# self.assertEqual(mem1, mem2)
self.assertEqual(mem2, mem3)
def test_dynamo_cache_move_to_front(self):
class Mod(torch.nn.Module):
def __init__(self):
@ -9445,6 +9483,56 @@ fn
c2 = _debug_get_cache_entry_list(fn.__code__)
self.assertIs(c1[1], c2[0])
def test_dynamo_cache_invalidate(self):
class Mod(torch.nn.Module):
def __init__(self):
super(Mod, self).__init__()
self.fc = torch.nn.Linear(3, 3)
def forward(self, out):
return self.fc(out)
def fn(x, mod):
return mod(x)
opt_fn = torch.compile(fn, backend="eager")
m1 = Mod()
m2 = Mod()
m3 = Mod()
inp = torch.randn(3, 3)
# NOTE: assumes that each cache entry is guarded
# on unique Mod instance
opt_fn(inp, m1)
opt_fn(inp, m2)
opt_fn(inp, m3)
c1 = _debug_get_cache_entry_list(fn.__code__)
self.assertEqual(len(c1), 3)
# move cache entry to front
opt_fn(inp, m2)
c2 = _debug_get_cache_entry_list(fn.__code__)
self.assertIs(c1[1], c2[0])
# delete center of cache
del m3
c3 = _debug_get_cache_entry_list(fn.__code__)
self.assertEqual(len(c3), 2)
self.assertIs(c3[0], c2[0])
self.assertIs(c3[1], c2[2])
# delete end of cache
del m1
c4 = _debug_get_cache_entry_list(fn.__code__)
self.assertEqual(len(c4), 1)
self.assertIs(c4[0], c3[0])
del m2
c5 = _debug_get_cache_entry_list(fn.__code__)
self.assertEqual(len(c5), 0)
class TestTracer(JitTestCase):
def test_jit_save(self):

View File

@ -18,4 +18,7 @@ class _CacheEntry:
code: types.CodeType
next: Optional[_CacheEntry]
class _ExtraState:
def invalidate(self, cache_entry: _CacheEntry): ...
def _debug_get_cache_entry_list(code: types.CodeType) -> List[_CacheEntry]: ...

View File

@ -67,12 +67,7 @@ if torch.manual_seed is torch.random.manual_seed:
def reset() -> None:
"""Clear all compile caches and restore initial state"""
with convert_frame.compile_lock:
for weak_code in (
convert_frame.input_codes.seen + convert_frame.output_codes.seen
):
code = weak_code()
if code:
reset_code(code)
reset_code_caches()
convert_frame.input_codes.clear()
convert_frame.output_codes.clear()
orig_code_map.clear()
@ -82,4 +77,15 @@ def reset() -> None:
_reset_guarded_backend_cache()
reset_frame_count()
torch._C._dynamo.compiled_autograd.clear_cache()
def reset_code_caches() -> None:
"""Clear compile caches that are keyed by code objects"""
with convert_frame.compile_lock:
for weak_code in (
convert_frame.input_codes.seen + convert_frame.output_codes.seen
):
code = weak_code()
if code:
reset_code(code)
code_context.clear()

View File

@ -20,6 +20,7 @@ import threading
import traceback
import types
import warnings
import weakref
from enum import Enum
from os.path import dirname, join
from typing import (
@ -384,7 +385,9 @@ class _TorchDynamoContext:
# Assume that the underlying node metadata of `fn`,
# a GraphModule instance, accurately represents
# all instances of type(fn).
code_context.get_context(fn.forward.__code__)["orig_graphmodule"] = fn
code_context.get_context(fn.forward.__code__)[
"orig_graphmodule"
] = weakref.ref(fn)
# Optimize the forward method of torch.nn.Module object
if isinstance(fn, torch.nn.Module):

View File

@ -55,7 +55,7 @@ from torch.utils.weak import TensorWeakRef
from . import config, convert_frame, exc, mutation_guard
from .eval_frame import set_guard_error_hook
from .source import DefaultsSource, LocalSource, TypeSource
from .types import GuardedCode, GuardFail, GuardFn # noqa: F401
from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401
from .utils import (
common_constant_types,
dict_keys_repr,
@ -931,16 +931,15 @@ def must_add_nn_module_guards(guard):
)
class DeletedGuardFn:
pass
# NB: Naively, you'd expect this to only be a function that produces
# the callable that constitutes the guard. However, there is some
# delicate handling for invalidating this check function when the
# locals/globals get invalidated, so there's some extra state
# we have to hold in this manager class.
#
# TODO: this object has reference cycle with itself, via check_fn which
# references back to CheckFunction via ___guarded_code in closure_vars.
# Ideally, there shouldn't be any ref cycle so that guards are
# promptly disposed of.
class CheckFunctionManager:
def __init__(
self,
@ -948,7 +947,6 @@ class CheckFunctionManager:
guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
):
guards = output_graph.guards if output_graph else None
self.valid = True
self._weakrefs: Dict[int, ReferenceType[object]] = {}
self.output_graph = output_graph
@ -1025,7 +1023,7 @@ class CheckFunctionManager:
guards_log.debug("GUARDS:")
# Don't report this guard, it's always the same, useless!
code_parts = ["___guarded_code.valid", "___check_global_state()"]
code_parts = ["___check_global_state()"]
verbose_code_parts = code_parts[:]
def add_code_part(code, guard, log_only=False):
@ -1157,7 +1155,6 @@ class CheckFunctionManager:
# we should only hit this case in NopTests()
global_state = convert_frame.GlobalStateGuard()
closure_vars = {
"___guarded_code": self,
"___check_tensors": check_tensors_fn,
"___check_tensors_verbose": check_tensors_verbose_fn,
"___check_global_state": global_state.check,
@ -1194,14 +1191,28 @@ class CheckFunctionManager:
# Grab only G, but preserve "G" because guards access it as "G"
guard_fn.global_scope = globals_for_guard_fn
guard_fn.guard_fail_fn = guard_fail_fn
# will be populated by a non-owning reference to CacheEntry/ExtraState
# when the CacheEntry is constructed
guard_fn.cache_entry = None
guard_fn.extra_state = None
return guard_fn
def invalidate(self):
# A weakref is no longer valid, self.check_fn should return false
# TODO(janimesh) - Free up cache entry after the cache entry formation
# is in python, and the underlying data structure is a doubly linked
# list.
self.valid = False
# Some tests reveal that CheckFunctionManager has no attribute
# check_fn, but this case should not be of any concern.
# This case doesn't seem easy to repro.
if (
hasattr(self, "check_fn")
and self.check_fn is not DeletedGuardFn
and (cache_entry := self.check_fn.cache_entry) is not None
and (extra_state := self.check_fn.extra_state) is not None
):
assert isinstance(cache_entry, CacheEntry)
assert isinstance(extra_state, ExtraState)
extra_state.invalidate(cache_entry)
self.check_fn.cache_entry = None
self.check_fn.extra_state = None
self.check_fn = DeletedGuardFn
def id_ref(self, obj):
"""add a weakref, return the id"""

View File

@ -1669,8 +1669,8 @@ class SubgraphTracer(fx.Tracer):
is_retracing = False
if tx.f_code is not self._cur_code:
orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
"orig_graphmodule", None
)
"orig_graphmodule", lambda: None
)()
if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
is_retracing = True
self._orig_gm_meta = [

View File

@ -2194,12 +2194,12 @@ class InstructionTranslator(InstructionTranslatorBase):
# Add original GraphModule context to the resume function to handle
# the case of a graph break while tracing a GraphModule
orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
"orig_graphmodule", None
)
"orig_graphmodule", lambda: None
)()
if orig_graphmodule_maybe is not None:
code_context.get_context(new_code)[
"orig_graphmodule"
] = orig_graphmodule_maybe
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
orig_graphmodule_maybe
)
if new_code.co_freevars:
cg.make_function_with_closure(name, new_code, True, stack_len)
@ -2347,7 +2347,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
# but it is enough to add a context for `forward` in case it is called.
code_context.get_context(module.forward.__code__)[
"orig_graphmodule"
] = module
] = weakref.ref(module)
tracer: InliningInstructionTranslator
if is_generator(code):

View File

@ -19,6 +19,8 @@ import torch
# and a `code` field for the code object.
CacheEntry = torch._C._dynamo.eval_frame._CacheEntry
ExtraState = torch._C._dynamo.eval_frame._ExtraState
# We use a dict to store additional data per frame.
FrameState = Dict[Any, Any]
@ -37,6 +39,8 @@ class GuardFn(Protocol):
verbose_code_parts: List[str]
global_scope: Dict[str, object]
guard_fail_fn: Optional[Callable[[GuardFail], None]]
cache_entry: Optional[CacheEntry]
extra_state: Optional[ExtraState]
# maps locals of user function to bool
def __call__(self, f_locals: Dict[str, object]) -> bool:

View File

@ -8,6 +8,12 @@ CacheEntry::CacheEntry(const py::handle& guarded_code) {
this->code = guarded_code.attr("code");
}
CacheEntry::~CacheEntry() {
// prevent check_fn from use-after-free when invalidating
this->check_fn.attr("cache_entry") = py::none();
this->check_fn.attr("extra_state") = py::none();
}
py::object CacheEntry::next() {
NULL_CHECK(this->_owner);
auto it = this->_owner_loc;

View File

@ -49,6 +49,7 @@ typedef struct VISIBILITY_HIDDEN CacheEntry {
std::list<CacheEntry>::iterator _owner_loc;
CacheEntry(const py::handle& guarded_code);
~CacheEntry();
// Warning: returns a reference whose lifetime is controlled by C++
py::object next();

View File

@ -22,6 +22,13 @@ void ExtraState::move_to_front(CacheEntry* cache_entry) {
cache_entry->_owner_loc);
}
void ExtraState::invalidate(CacheEntry* cache_entry) {
CHECK(cache_entry->_owner == this);
CHECK(!this->cache_entry_list.empty());
CHECK(cache_entry == &*cache_entry->_owner_loc);
this->cache_entry_list.erase(cache_entry->_owner_loc);
}
CacheEntry* extract_cache_entry(ExtraState* extra_state) {
if (extra_state == NULL || extra_state == SKIP_CODE) {
return NULL;
@ -110,6 +117,13 @@ CacheEntry* create_cache_entry(
auto new_iter = extra_state->cache_entry_list.begin();
new_iter->_owner = extra_state;
new_iter->_owner_loc = new_iter;
// Set check_fn references to extra_state and CacheEntry
// Warning: lifetime is controlled by C++!
py::handle check_fn = py::handle(guarded_code).attr("check_fn");
check_fn.attr("cache_entry") =
py::cast(*new_iter, py::return_value_policy::reference);
check_fn.attr("extra_state") =
py::cast(extra_state, py::return_value_policy::reference);
return &*new_iter;
}

View File

@ -41,6 +41,7 @@ typedef struct VISIBILITY_HIDDEN ExtraState {
CacheEntry* get_first_entry();
void move_to_front(CacheEntry* cache_entry);
void invalidate(CacheEntry* cache_entry);
} ExtraState;
#else

View File

@ -44,6 +44,9 @@ void initDynamoBindings(PyObject* torch) {
.def_readonly("code", &CacheEntry::code)
.def_property_readonly("next", &CacheEntry::next);
py::class_<ExtraState>(m, "_ExtraState")
.def("invalidate", &ExtraState::invalidate);
m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list);
}