mirror of https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] delete dynamo cache entry when guard function is invalidated [attempt 2] (#119107)
Attempt #2 for https://github.com/pytorch/pytorch/pull/117875 to fix https://github.com/pytorch/pytorch/issues/112090.

Summary of changes:

- ~Changed CacheEntry linked list into a doubly-linked list structure to support deletion.~ (done by C++ refactor)
- Added CacheEntry and ExtraState borrowed references to GuardFn so that GuardFn can tell ExtraState to delete CacheEntry when the GuardFn is invalidated.
- ~Added ExtraState raw reference to CacheEntry so that we can get ExtraState to correctly point to the first CacheEntry if it gets deleted.~ (done by C++ refactor)
- The CacheEntry destructor needs to reset GuardFn's references to ExtraState/CacheEntry in order to prevent use-after-free.
- code_context values that are fx.GraphModules need to be weakrefs in order to prevent circular references.
- Added tests that check for memory leaks and cache deletion operations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119107
Approved by: https://github.com/jansel
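Editor's note: the mechanism is easier to see in miniature. The sketch below is a simplified, illustrative Python model of the scheme described above, not dynamo's actual code (all names are stand-ins): each guard carries non-owning cache_entry/extra_state back-references, and invalidating the guard asks the owning state to unlink its entry.

    import weakref


    class Guard:
        """Stand-in for a compiled guard (GuardFn) with back-references."""

        def __init__(self):
            self.cache_entry = None  # populated when the entry is constructed
            self.extra_state = None

        def invalidate(self):
            if self.cache_entry is not None and self.extra_state is not None:
                self.extra_state.invalidate(self.cache_entry)
                # reset to avoid dangling references (mirrors ~CacheEntry below)
                self.cache_entry = None
                self.extra_state = None


    class CacheEntry:
        def __init__(self, check_fn, code):
            self.check_fn = check_fn
            self.code = code


    class ExtraState:
        """Owns the per-code list of cache entries."""

        def __init__(self):
            self.cache_entry_list = []

        def add(self, code):
            guard = Guard()
            entry = CacheEntry(guard, code)
            guard.cache_entry = entry  # non-owning back-references
            guard.extra_state = self
            self.cache_entry_list.insert(0, entry)
            return entry

        def invalidate(self, cache_entry):
            self.cache_entry_list.remove(cache_entry)


    state = ExtraState()
    entry = state.add(code="compiled artifact")

    # A guard typically watches some object; when it dies, the weakref
    # callback invalidates the guard, which unlinks the cache entry.
    class Watched:
        pass

    obj = Watched()
    _ref = weakref.ref(obj, lambda ref: entry.check_fn.invalidate())
    del obj  # CPython fires the callback here
    assert state.cache_entry_list == []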
parent fcc36de9d6
commit ee1c2449f7
@@ -9412,6 +9412,44 @@ fn
         c2 = _debug_get_cache_entry_list(fn.__code__)
         self.assertEqual(len(c2), 0)

+    @unittest.skipIf(not TEST_CUDA, "requires cuda")
+    def test_module_free(self):
+        """Test that CUDA memory is freed when a model goes out of scope"""
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super(Mod, self).__init__()
+                self.fc = torch.nn.Linear(10000, 10000)
+
+            def forward(self, out):
+                return self.fc(out)
+
+        def run(compile):
+            mod = Mod().cuda()
+            if compile:
+                mod = torch.compile(mod, backend="eager")
+            inp = torch.rand(10000, 10000).cuda()
+            mod(inp)
+
+        def clean_and_report_memory():
+            import gc
+
+            gc.collect()
+            return torch.cuda.memory_allocated()
+
+        run(False)
+        # mem1 = clean_and_report_memory()
+        run(True)
+        mem2 = clean_and_report_memory()
+        torch._dynamo.reset_code_caches()
+        mem3 = clean_and_report_memory()
+
+        # it's possible for dynamo to hold on to more memory
+        # even after a _dynamo.reset[_code_caches], so we omit the following check.
+        # self.assertEqual(mem1, mem2)
+
+        self.assertEqual(mem2, mem3)
+
     def test_dynamo_cache_move_to_front(self):
         class Mod(torch.nn.Module):
             def __init__(self):
@@ -9445,6 +9483,56 @@ fn
         c2 = _debug_get_cache_entry_list(fn.__code__)
         self.assertIs(c1[1], c2[0])

+    def test_dynamo_cache_invalidate(self):
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super(Mod, self).__init__()
+                self.fc = torch.nn.Linear(3, 3)
+
+            def forward(self, out):
+                return self.fc(out)
+
+        def fn(x, mod):
+            return mod(x)
+
+        opt_fn = torch.compile(fn, backend="eager")
+
+        m1 = Mod()
+        m2 = Mod()
+        m3 = Mod()
+        inp = torch.randn(3, 3)
+
+        # NOTE: assumes that each cache entry is guarded
+        # on unique Mod instance
+        opt_fn(inp, m1)
+        opt_fn(inp, m2)
+        opt_fn(inp, m3)
+
+        c1 = _debug_get_cache_entry_list(fn.__code__)
+        self.assertEqual(len(c1), 3)
+
+        # move cache entry to front
+        opt_fn(inp, m2)
+        c2 = _debug_get_cache_entry_list(fn.__code__)
+        self.assertIs(c1[1], c2[0])
+
+        # delete center of cache
+        del m3
+        c3 = _debug_get_cache_entry_list(fn.__code__)
+        self.assertEqual(len(c3), 2)
+        self.assertIs(c3[0], c2[0])
+        self.assertIs(c3[1], c2[2])
+
+        # delete end of cache
+        del m1
+        c4 = _debug_get_cache_entry_list(fn.__code__)
+        self.assertEqual(len(c4), 1)
+        self.assertIs(c4[0], c3[0])
+
+        del m2
+        c5 = _debug_get_cache_entry_list(fn.__code__)
+        self.assertEqual(len(c5), 0)
+

 class TestTracer(JitTestCase):
     def test_jit_save(self):
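Editor's note: as a usage sketch, the debug helper these tests rely on can be called directly to inspect a function's cache entries. The import path follows the type-stub hunk just below; treat this as illustrative rather than a stable API.

    import torch
    from torch._C._dynamo.eval_frame import _debug_get_cache_entry_list


    def f(x):
        return x + 1


    opt_f = torch.compile(f, backend="eager")
    opt_f(torch.randn(3))

    # one entry per compiled-and-guarded variant of f's code object
    entries = _debug_get_cache_entry_list(f.__code__)
    print(len(entries), entries[0].code)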
@@ -18,4 +18,7 @@ class _CacheEntry:
     code: types.CodeType
     next: Optional[_CacheEntry]

+class _ExtraState:
+    def invalidate(self, cache_entry: _CacheEntry): ...
+
 def _debug_get_cache_entry_list(code: types.CodeType) -> List[_CacheEntry]: ...
@@ -67,12 +67,7 @@ if torch.manual_seed is torch.random.manual_seed:
 def reset() -> None:
     """Clear all compile caches and restore initial state"""
     with convert_frame.compile_lock:
-        for weak_code in (
-            convert_frame.input_codes.seen + convert_frame.output_codes.seen
-        ):
-            code = weak_code()
-            if code:
-                reset_code(code)
+        reset_code_caches()
         convert_frame.input_codes.clear()
         convert_frame.output_codes.clear()
         orig_code_map.clear()
@@ -82,4 +77,15 @@ def reset() -> None:
     _reset_guarded_backend_cache()
     reset_frame_count()
     torch._C._dynamo.compiled_autograd.clear_cache()

+
+def reset_code_caches() -> None:
+    """Clear compile caches that are keyed by code objects"""
+    with convert_frame.compile_lock:
+        for weak_code in (
+            convert_frame.input_codes.seen + convert_frame.output_codes.seen
+        ):
+            code = weak_code()
+            if code:
+                reset_code(code)
+        code_context.clear()
@@ -20,6 +20,7 @@ import threading
 import traceback
 import types
 import warnings
+import weakref
 from enum import Enum
 from os.path import dirname, join
 from typing import (
@@ -384,7 +385,9 @@ class _TorchDynamoContext:
             # Assume that the underlying node metadata of `fn`,
             # a GraphModule instance, accurately represents
             # all instances of type(fn).
-            code_context.get_context(fn.forward.__code__)["orig_graphmodule"] = fn
+            code_context.get_context(fn.forward.__code__)[
+                "orig_graphmodule"
+            ] = weakref.ref(fn)

         # Optimize the forward method of torch.nn.Module object
         if isinstance(fn, torch.nn.Module):
@@ -55,7 +55,7 @@ from torch.utils.weak import TensorWeakRef
 from . import config, convert_frame, exc, mutation_guard
 from .eval_frame import set_guard_error_hook
 from .source import DefaultsSource, LocalSource, TypeSource
-from .types import GuardedCode, GuardFail, GuardFn  # noqa: F401
+from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn  # noqa: F401
 from .utils import (
     common_constant_types,
     dict_keys_repr,
@@ -931,16 +931,15 @@ def must_add_nn_module_guards(guard):
 )


+class DeletedGuardFn:
+    pass
+
+
 # NB: Naively, you'd expect this to only be a function that produces
 # the callable that constitutes the guard. However, there is some
 # delicate handling for invalidating this check function when the
 # locals/globals get invalidated, so there's some extra state
 # we have to hold in this manager class.
-#
-# TODO: this object has reference cycle with itself, via check_fn which
-# references back to CheckFunction via ___guarded_code in closure_vars.
-# Ideally, there shouldn't be any ref cycle so that guards are
-# promptly disposed of.
 class CheckFunctionManager:
     def __init__(
         self,
@@ -948,7 +947,6 @@ class CheckFunctionManager:
         guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
     ):
         guards = output_graph.guards if output_graph else None
-        self.valid = True
         self._weakrefs: Dict[int, ReferenceType[object]] = {}
         self.output_graph = output_graph
@@ -1025,7 +1023,7 @@ class CheckFunctionManager:
         guards_log.debug("GUARDS:")

         # Don't report this guard, it's always the same, useless!
-        code_parts = ["___guarded_code.valid", "___check_global_state()"]
+        code_parts = ["___check_global_state()"]
         verbose_code_parts = code_parts[:]

         def add_code_part(code, guard, log_only=False):
@@ -1157,7 +1155,6 @@ class CheckFunctionManager:
             # we should only hit this case in NopTests()
             global_state = convert_frame.GlobalStateGuard()
         closure_vars = {
-            "___guarded_code": self,
             "___check_tensors": check_tensors_fn,
             "___check_tensors_verbose": check_tensors_verbose_fn,
             "___check_global_state": global_state.check,
@@ -1194,14 +1191,28 @@ class CheckFunctionManager:
         # Grab only G, but preserve "G" because guards access it as "G"
         guard_fn.global_scope = globals_for_guard_fn
         guard_fn.guard_fail_fn = guard_fail_fn
+        # will be populated by a non-owning reference to CacheEntry/ExtraState
+        # when the CacheEntry is constructed
+        guard_fn.cache_entry = None
+        guard_fn.extra_state = None
         return guard_fn

     def invalidate(self):
-        # A weakref is no longer valid, self.check_fn should return false
-        # TODO(janimesh) - Free up cache entry after the cache entry formation
-        # is in python, and the underlying data structure is a doubly linked
-        # list.
-        self.valid = False
+        # Some tests reveal that CheckFunctionManager has no attribute
+        # check_fn, but this case should not be of any concern.
+        # This case doesn't seem easy to repro.
+        if (
+            hasattr(self, "check_fn")
+            and self.check_fn is not DeletedGuardFn
+            and (cache_entry := self.check_fn.cache_entry) is not None
+            and (extra_state := self.check_fn.extra_state) is not None
+        ):
+            assert isinstance(cache_entry, CacheEntry)
+            assert isinstance(extra_state, ExtraState)
+            extra_state.invalidate(cache_entry)
+            self.check_fn.cache_entry = None
+            self.check_fn.extra_state = None
+            self.check_fn = DeletedGuardFn

     def id_ref(self, obj):
         """add a weakref, return the id"""
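Editor's note: what actually calls invalidate() is object death, not a per-call validity flag: the manager keeps weakrefs to objects its guards depend on (see id_ref above), and a dying referent triggers invalidation, which unlinks the cache entry; DeletedGuardFn then marks a manager whose guard has already been torn down. A rough illustration of the callback wiring (illustrative names, not the exact dynamo code):

    import weakref


    class Manager:
        """Illustrative wiring only: invalidate when a guarded object dies."""

        def __init__(self):
            self._weakrefs = {}

        def track(self, obj):
            # keep a weakref whose death triggers invalidation, roughly what
            # the real manager arranges for objects its guards depend on
            self._weakrefs[id(obj)] = weakref.ref(obj, lambda ref: self.invalidate())
            return id(obj)

        def invalidate(self):
            print("guarded object died; unlink this manager's cache entry")


    class Watched:
        pass


    mgr = Manager()
    obj = Watched()
    mgr.track(obj)
    del obj  # the callback runs when obj is collected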
@@ -1669,8 +1669,8 @@ class SubgraphTracer(fx.Tracer):
         is_retracing = False
         if tx.f_code is not self._cur_code:
             orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
-                "orig_graphmodule", None
-            )
+                "orig_graphmodule", lambda: None
+            )()
             if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
                 is_retracing = True
                 self._orig_gm_meta = [
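Editor's note on the changed default: code_context now stores a weakref.ref for "orig_graphmodule", so callers pass lambda: None as the .get default and call the result unconditionally; whether the key holds a live weakref, a dead weakref, or is absent, the expression yields the module or None without branching. A small demonstration of the idiom (illustrative):

    import weakref


    class Thing:
        pass


    obj = Thing()
    ctx_hit = {"orig_graphmodule": weakref.ref(obj)}
    ctx_miss = {}

    # both paths are callable, so no branching on key presence is needed
    assert ctx_hit.get("orig_graphmodule", lambda: None)() is obj
    assert ctx_miss.get("orig_graphmodule", lambda: None)() is None

    del obj  # a dead weakref also yields None
    assert ctx_hit.get("orig_graphmodule", lambda: None)() is None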
@@ -2194,12 +2194,12 @@ class InstructionTranslator(InstructionTranslatorBase):
         # Add original GraphModule context to the resume function to handle
         # the case of a graph break while tracing a GraphModule
         orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
-            "orig_graphmodule", None
-        )
+            "orig_graphmodule", lambda: None
+        )()
         if orig_graphmodule_maybe is not None:
-            code_context.get_context(new_code)[
-                "orig_graphmodule"
-            ] = orig_graphmodule_maybe
+            code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
+                orig_graphmodule_maybe
+            )

         if new_code.co_freevars:
             cg.make_function_with_closure(name, new_code, True, stack_len)
@@ -2347,7 +2347,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
         # but it is enough to add a context for `forward` in case it is called.
         code_context.get_context(module.forward.__code__)[
             "orig_graphmodule"
-        ] = module
+        ] = weakref.ref(module)

         tracer: InliningInstructionTranslator
         if is_generator(code):
@@ -19,6 +19,8 @@ import torch
 # and a `code` field for the code object.
 CacheEntry = torch._C._dynamo.eval_frame._CacheEntry

+ExtraState = torch._C._dynamo.eval_frame._ExtraState
+
 # We use a dict to store additional data per frame.
 FrameState = Dict[Any, Any]
@@ -37,6 +39,8 @@ class GuardFn(Protocol):
     verbose_code_parts: List[str]
     global_scope: Dict[str, object]
     guard_fail_fn: Optional[Callable[[GuardFail], None]]
+    cache_entry: Optional[CacheEntry]
+    extra_state: Optional[ExtraState]

     # maps locals of user function to bool
     def __call__(self, f_locals: Dict[str, object]) -> bool:
@@ -8,6 +8,12 @@ CacheEntry::CacheEntry(const py::handle& guarded_code) {
   this->code = guarded_code.attr("code");
 }

+CacheEntry::~CacheEntry() {
+  // prevent check_fn from use-after-free when invalidating
+  this->check_fn.attr("cache_entry") = py::none();
+  this->check_fn.attr("extra_state") = py::none();
+}
+
 py::object CacheEntry::next() {
   NULL_CHECK(this->_owner);
   auto it = this->_owner_loc;
@@ -49,6 +49,7 @@ typedef struct VISIBILITY_HIDDEN CacheEntry {
   std::list<CacheEntry>::iterator _owner_loc;

   CacheEntry(const py::handle& guarded_code);
+  ~CacheEntry();

   // Warning: returns a reference whose lifetime is controlled by C++
   py::object next();
@@ -22,6 +22,13 @@ void ExtraState::move_to_front(CacheEntry* cache_entry) {
       cache_entry->_owner_loc);
 }

+void ExtraState::invalidate(CacheEntry* cache_entry) {
+  CHECK(cache_entry->_owner == this);
+  CHECK(!this->cache_entry_list.empty());
+  CHECK(cache_entry == &*cache_entry->_owner_loc);
+  this->cache_entry_list.erase(cache_entry->_owner_loc);
+}
+
 CacheEntry* extract_cache_entry(ExtraState* extra_state) {
   if (extra_state == NULL || extra_state == SKIP_CODE) {
     return NULL;
@@ -110,6 +117,13 @@ CacheEntry* create_cache_entry(
   auto new_iter = extra_state->cache_entry_list.begin();
   new_iter->_owner = extra_state;
   new_iter->_owner_loc = new_iter;
+  // Set check_fn references to extra_state and CacheEntry
+  // Warning: lifetime is controlled by C++!
+  py::handle check_fn = py::handle(guarded_code).attr("check_fn");
+  check_fn.attr("cache_entry") =
+      py::cast(*new_iter, py::return_value_policy::reference);
+  check_fn.attr("extra_state") =
+      py::cast(extra_state, py::return_value_policy::reference);
   return &*new_iter;
 }
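Editor's note: the O(1) deletion relies on std::list semantics. Each CacheEntry stores its own iterator (_owner_loc), and erasing through that iterator neither moves nor invalidates the other entries, which a vector could not guarantee. A Python analogue with explicit links (illustrative):

    class Node:
        def __init__(self, value):
            self.value = value
            self.prev = self.next = None


    class Ring:
        """Doubly linked list with a sentinel, mimicking std::list."""

        def __init__(self):
            self.sentinel = Node(None)
            self.sentinel.prev = self.sentinel.next = self.sentinel

        def push_front(self, value):
            node = Node(value)
            head = self.sentinel.next
            node.prev, node.next = self.sentinel, head
            self.sentinel.next = head.prev = node
            return node  # caller keeps this handle, like CacheEntry::_owner_loc

        def erase(self, node):
            # O(1): unlink using only the node's own links, no traversal
            node.prev.next = node.next
            node.next.prev = node.prev
            node.prev = node.next = None


    ring = Ring()
    n1 = ring.push_front("entry1")
    n2 = ring.push_front("entry2")
    ring.erase(n1)  # other handles (n2) remain valid
    assert ring.sentinel.next is n2 and n2.next is ring.sentinel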
@@ -41,6 +41,7 @@ typedef struct VISIBILITY_HIDDEN ExtraState {

   CacheEntry* get_first_entry();
   void move_to_front(CacheEntry* cache_entry);
+  void invalidate(CacheEntry* cache_entry);
 } ExtraState;

 #else
@@ -44,6 +44,9 @@ void initDynamoBindings(PyObject* torch) {
       .def_readonly("code", &CacheEntry::code)
       .def_property_readonly("next", &CacheEntry::next);

+  py::class_<ExtraState>(m, "_ExtraState")
+      .def("invalidate", &ExtraState::invalidate);
+
   m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list);
 }