Cache output tensors on execution (#98944)

Caches output tensors for the common case where the output Tensor's storage is unaliased for all graph outputs in all paths. For these persisted tensors we adjust the liveness tracking by also checking that the output tensor does not have an additional Python reference.

I limit cached output tensors to those that are unaliased. If a descendant node discovers it has an alias of a prior output, then the aliased output will no longer be persisted in the ancestor.

The large majority of tensors are unaliased, and preserving aliased output tensors would add significant additional complexity with marginal gains. For instance, when we do checkpointing and re-recording, we need to remove the persisted tensors; otherwise they would prevent memory from being reclaimed. If a single persisted tensor were present in multiple paths, that would create an inter-path dependency, which adds complexity. Additionally, each further caching of the output would affect the reference count of the other caches, and that reference count would also need to be adjusted depending on whether a node was checkpointed.

I still need to do a complete run, but for the models I tried this makes the performance of the trees and non-trees implementations extremely close.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98944
Approved by: https://github.com/jansel, https://github.com/ngimel
This commit is contained in:
Elias Ellison 2023-04-17 21:44:00 +00:00 committed by PyTorch MergeBot
parent 93b64f0ad3
commit 472f46635e
5 changed files with 301 additions and 60 deletions

View File

@ -100,10 +100,17 @@ if HAS_CUDA and not TEST_WITH_ASAN:
class CudaGraphTreeTests(TestCase): class CudaGraphTreeTests(TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.prev_enabled = config.triton.cudagraphs self.graph_stack = contextlib.ExitStack()
self.tapes_enabled = config.triton.cudagraph_trees self.graph_stack.enter_context(
config.triton.cudagraphs = True config.patch(
config.triton.cudagraph_trees = True {
"triton.cudagraphs": True,
"triton.cudagraph_trees": True,
"triton.fast_path_cudagraph_asserts": True, # too slow
"triton.slow_path_cudagraph_asserts": True,
}
)
)
self.device_idx = torch.rand([0], device="cuda").device.index self.device_idx = torch.rand([0], device="cuda").device.index
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -111,10 +118,12 @@ if HAS_CUDA and not TEST_WITH_ASAN:
super().tearDown() super().tearDown()
torch._dynamo.reset() torch._dynamo.reset()
gc.collect() gc.collect()
config.triton.cudagraphs = self.prev_enabled torch.cuda.empty_cache()
config.triton.cudagraph_trees = self.tapes_enabled self.graph_stack.close()
self.assertIsNone(self.get_manager()) self.assertIsNone(self.get_manager())
self.assertEqual(all_live_block_count(), 0) self.assertEqual(all_live_block_count(), 0)
self.assertEqual(len(get_all_cudagraph_segments()), 0)
warnings.resetwarnings() warnings.resetwarnings()
def get_manager(self, device_index=None): def get_manager(self, device_index=None):
@ -463,7 +472,11 @@ if HAS_CUDA and not TEST_WITH_ASAN:
ptr_to_ref[out.untyped_storage().data_ptr()], ptr_to_ref[out.untyped_storage().data_ptr()],
out.untyped_storage()._cdata, out.untyped_storage()._cdata,
) )
del outs
del out
node = self.get_manager().current_node
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
self.assertFalse(self.get_manager().new_graph_id().id == 0) self.assertFalse(self.get_manager().new_graph_id().id == 0)
def test_aliasing_static_ref(self): def test_aliasing_static_ref(self):
@ -484,15 +497,63 @@ if HAS_CUDA and not TEST_WITH_ASAN:
x = torch.rand([10, 10], device="cuda", requires_grad=True) x = torch.rand([10, 10], device="cuda", requires_grad=True)
param_c = cdata(m.weight) param_c = cdata(m.weight)
for _ in range(3): for _ in range(3):
# print("Runnng foo")
out1, alias_1, alias_2 = foo(m, x) out1, alias_1, alias_2 = foo(m, x)
self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1) self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)
# print("Runnng foo2")
out2 = foo2(out1) out2 = foo2(out1)
out2.sum().backward() out2.sum().backward()
self.assertEqual(cdata(out1), cdata(out2)) self.assertEqual(cdata(out1), cdata(out2))
node = self.curr_node()
first_node = next(node._path_from_root)
self.assertFalse(first_node.unaliased_in_all_paths[0])
self.assertTrue(first_node.cached_tensor_outputs[0] is None)
def test_checkpointing_resets_persistent_refs(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
return x @ x
def inp():
return torch.rand([20, 20], device="cuda", requires_grad=False)
for _ in range(3):
foo(inp())
self.assertEqual(self.num_checkpoints(), 0)
out = foo(inp())
out_id = id(out)
del out
self.assertEqual(id(foo(inp())), out_id)
@torch.compile(mode="reduce-overhead")
def foo2(x):
return x[0], x @ x
for i in range(2):
out = foo(inp())
from torch._dynamo.mutation_guard import GenerationTracker
GenerationTracker.generation -= 1
out_alias, out2 = foo2(out)
del out_alias
self.assertEqual(all_live_block_count(), 2)
del out
self.assertEqual(all_live_block_count(), 1)
del out2
self.assertEqual(all_live_block_count(), 0)
self.assertEqual(self.num_checkpoints(), i + 1)
new_out = foo(inp())
curr_node = self.curr_node()
self.assertFalse(curr_node.unaliased_in_all_paths[0])
self.assertFalse(out_id == id(new_out))
def test_aliased_static_parameter(self): def test_aliased_static_parameter(self):
inp = torch.rand([20, 20], device="cuda") inp = torch.rand([20, 20], device="cuda")
@ -507,6 +568,29 @@ if HAS_CUDA and not TEST_WITH_ASAN:
out = foo_cg([inp])[0] out = foo_cg([inp])[0]
self.assertEqual(cdata(inp), cdata(out)) self.assertEqual(cdata(inp), cdata(out))
node = self.curr_node()
self.assertEqual(node.cached_tensor_outputs, [None])
self.assertEqual(node.unaliased_in_all_paths, [False])
def test_output_alias(self):
inp = torch.rand([20, 20], device="cuda")
def foo(args):
x = args[0]
args.clear()
out = x + x
return (x, x[0])
foo_cg = self.cudagraphify_impl(foo, [inp])
for _ in range(3):
out_1, out_2 = foo_cg([inp])
self.assertEqual(cdata(out_1), cdata(out_2))
del out_1, out_2
self.assertEqual(len(list(self.curr_node().path_live_weakrefs())), 0)
self.assertEqual(self.curr_node().cached_tensor_outputs, [None, None])
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True) @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_aliased_output_checkpoint(self): def test_aliased_output_checkpoint(self):
def foo(args): def foo(args):
@ -543,6 +627,28 @@ if HAS_CUDA and not TEST_WITH_ASAN:
del x del x
self.assertEqual(all_live_block_count(), 0) self.assertEqual(all_live_block_count(), 0)
def test_peristed_output_livenes(self):
@torch.compile
def foo(x):
return x + x
for _ in range(3):
foo(torch.rand([2, 2], device="cuda"))
node = self.get_manager().current_node
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
out = foo(torch.rand([2, 2], device="cuda"))
self.assertTrue(out is node.cached_tensor_outputs[0])
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
out_ref = out[0:]
del out
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
del out_ref
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True) @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_tensor_no_longer_in_pool(self): def test_tensor_no_longer_in_pool(self):
def foo(args): def foo(args):
@ -771,11 +877,13 @@ if HAS_CUDA and not TEST_WITH_ASAN:
inp = torch.rand([4], device="cuda", requires_grad=True) inp = torch.rand([4], device="cuda", requires_grad=True)
streams = set() streams = set()
streams_init = {seg["stream"] for seg in get_all_cudagraph_segments()}
for _ in range(3): for _ in range(4):
foo(inp).sum().backward() foo(inp).sum().backward()
streams = {seg["stream"] for seg in get_all_cudagraph_segments()} streams = {
seg["stream"] for seg in get_all_cudagraph_segments()
} - streams_init
self.assertEqual(len(streams), 1) self.assertEqual(len(streams), 1)
self.assertFalse(self.get_manager().new_graph_id().id == 0) self.assertFalse(self.get_manager().new_graph_id().id == 0)

View File

@ -1529,6 +1529,12 @@ class _TorchCompileInductorWrapper:
return compile_fx(model_, inputs_, config_patches=self.config) return compile_fx(model_, inputs_, config_patches=self.config)
def reset(self):
from torch._inductor import config
if "triton.cudagraphs" in self.config or config.triton.cudagraphs:
if self.config.get("triton.cudagraphs", True):
from torch._inductor.cudagraph_trees import reset_cudagraph_trees
reset_cudagraph_trees()
def compile(model: Optional[Callable] = None, *, def compile(model: Optional[Callable] = None, *,
fullgraph: builtins.bool = False, fullgraph: builtins.bool = False,

View File

@ -51,6 +51,8 @@ def reset():
orig_code_map.clear() orig_code_map.clear()
guard_failures.clear() guard_failures.clear()
resume_execution.ContinueExecutionCache.cache.clear() resume_execution.ContinueExecutionCache.cache.clear()
if hasattr(eval_frame.most_recent_backend, "reset"):
eval_frame.most_recent_backend.reset()
eval_frame.most_recent_backend = None eval_frame.most_recent_backend = None
compilation_metrics.clear() compilation_metrics.clear()
reset_frame_count() reset_frame_count()

View File

@ -41,6 +41,7 @@ import dataclasses
import functools import functools
import gc import gc
import itertools import itertools
import sys
import threading import threading
import warnings import warnings
import weakref import weakref
@ -245,6 +246,22 @@ torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_contain
torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks) torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks)
def reset_cudagraph_trees():
"Clear all cudagraph trees"
# see remove_all_cached_tensors below for why this is necessary
container_dict = get_obj(local, "tree_manager_containers")
locks_dict = get_obj(local, "tree_manager_locks")
for device, lock in locks_dict.items():
with lock:
container = container_dict.get(device)
if not container or not container.tree_manager:
continue
container.tree_manager.remove_all_cached_tensors()
container_dict.clear()
def get_obj(local, attr_name): def get_obj(local, attr_name):
if hasattr(local, attr_name): if hasattr(local, attr_name):
return getattr(local, attr_name) return getattr(local, attr_name)
@ -324,11 +341,15 @@ class StorageWeakRefWrapper:
Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked. Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked.
""" """
__slots__ = ["ref", "_data_ptr"] __slots__ = ["ref", "_data_ptr", "extra_ref_check"]
storage_ref: Optional[StorageWeakRef] storage_ref: Optional[StorageWeakRef]
def __init__(self, inp: Union[Tensor, UntypedStorage]): def __init__(
self,
inp: Union[Tensor, UntypedStorage],
extra_ref_check: Optional[Callable[[], None]] = None,
):
if isinstance(inp, Tensor): if isinstance(inp, Tensor):
stor = inp.untyped_storage() stor = inp.untyped_storage()
else: else:
@ -336,28 +357,41 @@ class StorageWeakRefWrapper:
stor = inp stor = inp
self.ref = StorageWeakRef(stor) self.ref = StorageWeakRef(stor)
self._data_ptr = stor.data_ptr() self._data_ptr = stor.data_ptr()
self.extra_ref_check = extra_ref_check
@classmethod @classmethod
def from_weakref_and_data_ptr(cls, cdata, data_ptr): def from_weakref_and_data_ptr(cls, cdata, data_ptr, extra_ref_check=None):
instance = cls.__new__(cls) instance = cls.__new__(cls)
instance._data_ptr = data_ptr instance._data_ptr = data_ptr
instance.ref = StorageWeakRef.from_weakref(cdata) instance.ref = StorageWeakRef.from_weakref(cdata)
instance.extra_ref_check = extra_ref_check
return instance return instance
def __call__(self) -> Optional[StorageWeakRefPointer]: def __call__(self) -> Optional[StorageWeakRefPointer]:
if self.ref is None: if self.expired():
return None
if self.ref.expired():
self.ref = None
return None return None
return self.ref.cdata return self.ref.cdata
def swap_weakref(self, cdata):
self.ref.__del__()
self.ref.cdata = cdata
def data_ptr(self) -> int: def data_ptr(self) -> int:
"NB: returns the data ptr even if the storage has expired" "NB: returns the data ptr even if the storage has expired"
return self._data_ptr return self._data_ptr
def remove_extra_reference(self):
self.extra_ref_check = None
def expired(self):
if self.extra_ref_check is not None and not self.extra_ref_check():
return False
# if extra_ref_check is not None we expect an additional reference
stor_count = torch._C._storage_Use_Count(self.ref.cdata)
return (stor_count - (self.extra_ref_check is not None)) == 0
def __repr__(self): def __repr__(self):
if self.ref is None or self.ref.expired(): if self.ref is None or self.ref.expired():
return f"StorageWeakRefWrapper to {self.data_ptr()}; dead" return f"StorageWeakRefWrapper to {self.data_ptr()}; dead"
@ -725,6 +759,18 @@ class CUDAGraphNode:
# - An alias of an output already created in the reconstructed outputs # - An alias of an output already created in the reconstructed outputs
self.output_storage_alias: OutputList[OutputAliasInfo] = [] self.output_storage_alias: OutputList[OutputAliasInfo] = []
# Is the output Storage unaliased in subsequent outputs, of all subsequent paths?
# If it is, we cache the output tensor and adjust storage liveness tracking to also
# check that the output tensor does not have an additional Python reference.
# If a descendant node discovers it has an alias of a prior output, then the output
# will no longer be cached in the ancestor.
# The large majority of tensors are unaliased, and preserving aliased output tensors would add
# significant additional complexity with marginal gains
# The cached tensor outputs are added on the first execution, and cleared whenever we need
# to do subsequent recording
self.unaliased_in_all_paths: OutputList[bool] = []
self.cached_tensor_outputs: OutputList[Optional[Tensor]] = []
# if an output aliases a static, persistent input then the Storage of the # if an output aliases a static, persistent input then the Storage of the
# persistent output will be set here # persistent output will be set here
self.output_persistent_storage: OutputList[Optional[UntypedStorage]] = [] self.output_persistent_storage: OutputList[Optional[UntypedStorage]] = []
@ -788,8 +834,6 @@ class CUDAGraphNode:
self.run_graph() self.run_graph()
outputs = self.reconstruct_outputs() outputs = self.reconstruct_outputs()
self._add_replayed_outputs(outputs)
self.debug_check_invariants_after_invocation() self.debug_check_invariants_after_invocation()
return outputs return outputs
@ -797,35 +841,44 @@ class CUDAGraphNode:
def reconstruct_outputs(self): def reconstruct_outputs(self):
"Reconstruct output tensors according to their saved metadata and alias information" "Reconstruct output tensors according to their saved metadata and alias information"
# The cpp function is constructing a new Tensor according to the saved output metadata # Cached tensors will not yet be set on the first execution
# For each element in the corresponding storage list: # They are also cleared in checkpointing, so if we checkpoint this node
# - if a Storage is contained, that will be used # and then execute it again we will need to repopulate cached tensors
# - if None is contained, a new Storage will be constructed if not self.cached_tensor_outputs:
# - if an int is contained, the storage from the output list at that int will be used self._initialize_cached_tensors()
storages_info: List[
Union[UntypedStorage, None, int]
] = self.prepare_storages_for_construction()
outputs_new = []
# # We recreate the below logic in cpp to reduce overhead, since this is on the hot path outputs = []
for storage_info, metadata in zip(storages_info, self.outputs_metadata):
for i, (storage_info, metadata) in enumerate(
zip(self.output_storage_alias, self.outputs_metadata)
):
if metadata is None: if metadata is None:
outputs_new.append(None) outputs.append(None)
continue continue
if storage_info is None: cached_t = self.cached_tensor_outputs[i]
s = self.create_storage(metadata) if cached_t is not None:
elif isinstance(storage_info, UntypedStorage): # No need to update weakrefs, already correctly initialized
s = storage_info outputs.append(cached_t)
else: continue
assert isinstance(storage_info, int)
s = outputs_new[storage_info].untyped_storage()
outputs_new.append( storage = self.prepare_alias_info_for_tensor_construction(
self._reconstruct_from_tensor_metadata(metadata, storage=s) i, storage_info, metadata
) )
return outputs_new if isinstance(storage, UntypedStorage) or storage is None:
out = self._reconstruct_from_tensor_metadata(metadata, storage)
else:
assert isinstance(storage, int)
out = self._reconstruct_from_tensor_metadata(
metadata, outputs[storage].untyped_storage()
)
outputs.append(out)
if storage_info is not PersistentStaticStorage:
self.outputs_weakrefs[i].swap_weakref(out.untyped_storage()._weak_ref())
return outputs
def prepare_alias_info_for_tensor_construction( def prepare_alias_info_for_tensor_construction(
self, out_index: int, out_alias_info: OutputAliasInfo, metadata: Dict[str, Any] self, out_index: int, out_alias_info: OutputAliasInfo, metadata: Dict[str, Any]
@ -839,7 +892,6 @@ class CUDAGraphNode:
if isinstance(out_alias_info, AliasesPriorGraphOutput): if isinstance(out_alias_info, AliasesPriorGraphOutput):
depth, existing_output_index = out_alias_info.index depth, existing_output_index = out_alias_info.index
ref = self.path_weakrefs[depth][existing_output_index] ref = self.path_weakrefs[depth][existing_output_index]
assert ref()
return torch.UntypedStorage._new_with_weak_ptr(ref()) return torch.UntypedStorage._new_with_weak_ptr(ref())
assert isinstance(out_alias_info, AliasesNewOutput) assert isinstance(out_alias_info, AliasesNewOutput)
@ -928,6 +980,8 @@ class CUDAGraphNode:
# index from data pointer to index in outputs # index from data pointer to index in outputs
output_new_storages_index: Dict[StorageDataPtr, int] = {} output_new_storages_index: Dict[StorageDataPtr, int] = {}
self.unaliased_in_all_paths = [False for _ in range(len(outputs))]
for i, o in enumerate(outputs): for i, o in enumerate(outputs):
if o is None: if o is None:
self.output_storage_alias.append(UnaliasedStorage) self.output_storage_alias.append(UnaliasedStorage)
@ -948,19 +1002,19 @@ class CUDAGraphNode:
path_ref = self._is_alias_of_live_recorded_tensor(o) path_ref = self._is_alias_of_live_recorded_tensor(o)
if path_ref is not None: if path_ref is not None:
self._mark_prior_graph_output_as_aliased(path_ref)
self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref)) self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
continue continue
if o.untyped_storage().data_ptr() in output_new_storages_index: if o.untyped_storage().data_ptr() in output_new_storages_index:
self.output_storage_alias.append( index = output_new_storages_index[o.untyped_storage().data_ptr()]
AliasesNewOutput( self.unaliased_in_all_paths[index] = False
output_new_storages_index[o.untyped_storage().data_ptr()] self.output_storage_alias.append(AliasesNewOutput(index))
)
)
continue continue
output_new_storages_index[o.untyped_storage().data_ptr()] = i output_new_storages_index[o.untyped_storage().data_ptr()] = i
self.output_storage_alias.append(UnaliasedStorage) self.output_storage_alias.append(UnaliasedStorage)
self.unaliased_in_all_paths[i] = True
if self.stack_traces is None: if self.stack_traces is None:
self.stack_traces = [None for _ in range(len(outputs))] self.stack_traces = [None for _ in range(len(outputs))]
@ -969,9 +1023,14 @@ class CUDAGraphNode:
outputs outputs
), "Wrong number of stack traces passed in" ), "Wrong number of stack traces passed in"
self._add_replayed_outputs(outputs) assert not self.outputs_weakrefs
self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs) for out, persisted_storage in zip(outputs, self.output_persistent_storage):
if out is None or persisted_storage is not None:
self.outputs_weakrefs.append(None)
else:
self.outputs_weakrefs.append(StorageWeakRefWrapper(out))
self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs)
self.checkpointed_caching_state = torch._C._cuda_getCheckpointState( self.checkpointed_caching_state = torch._C._cuda_getCheckpointState(
self.device, self.cuda_graphs_pool self.device, self.cuda_graphs_pool
) )
@ -986,15 +1045,46 @@ class CUDAGraphNode:
if config.triton.slow_path_cudagraph_asserts: if config.triton.slow_path_cudagraph_asserts:
check_memory_pool(self.cuda_graphs_pool, list(self.path_live_weakrefs())) check_memory_pool(self.cuda_graphs_pool, list(self.path_live_weakrefs()))
def _add_replayed_outputs(self, outputs): def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex):
self.outputs_weakrefs.clear() "Remove a graph output from the unaliased, cached tensors in an ancestor node"
depth, output_index = index
node = list(self._path_from_root)[depth]
node.unaliased_in_all_paths[output_index] = False
self.path_weakrefs[depth][output_index].remove_extra_reference()
for out, persistent_storage in zip(outputs, self.output_persistent_storage): def _initialize_cached_tensors(self):
if out is None or persistent_storage is not None: # we should not be clearing output_weakrefs, and they should be set in the first
self.outputs_weakrefs.append(None) # record run
assert len(self.outputs_weakrefs) == len(self.outputs_metadata)
for i, (storage_info, metadata, make_cached) in enumerate(
zip(
self.output_storage_alias,
self.outputs_metadata,
self.unaliased_in_all_paths,
)
):
if not make_cached:
self.cached_tensor_outputs.append(None)
continue continue
self.outputs_weakrefs.append(StorageWeakRefWrapper(out)) assert storage_info is UnaliasedStorage
s = self.create_storage(metadata)
out = self._reconstruct_from_tensor_metadata(metadata, storage=s)
self_ref = weakref.ref(self)
# one reference in our array, and calling sys.getrefcount bumps the refcount by one
def check_refcount(i):
return self_ref().get_output_refcount(i) == 2
check = functools.partial(check_refcount, i=i)
self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check)
self.cached_tensor_outputs.append(out)
def get_output_refcount(self, index):
return sys.getrefcount(self.cached_tensor_outputs[index])
@property @property
def parent(self): def parent(self):
@ -1156,6 +1246,16 @@ class CUDAGraphNode:
if is_live(out): if is_live(out):
yield out yield out
def remove_node_cached_tensors(self):
self.cached_tensor_outputs.clear()
for i, unaliased in enumerate(self.unaliased_in_all_paths):
if unaliased:
self.outputs_weakrefs[i].remove_extra_reference()
def remove_path_cached_tensors(self):
for node in self._path_from_root:
node.remove_node_cached_tensors()
def path_live_weakrefs_and_stacktraces( def path_live_weakrefs_and_stacktraces(
self, self,
) -> Generator[Tuple[StorageWeakRefWrapper, Optional[str]]]: ) -> Generator[Tuple[StorageWeakRefWrapper, Optional[str]]]:
@ -1166,9 +1266,9 @@ class CUDAGraphNode:
yield out, self.path_stacktraces[i][j] yield out, self.path_stacktraces[i][j]
def clear_path_state(self): def clear_path_state(self):
"Clear the output lists of all nodes in the path and the storage cache" "Clear the path state in this current executing node"
for li in self.path_weakrefs: # this doesnt actually do anything right now, leaving it as placeholder
li.clear() pass
@staticmethod @staticmethod
def _tensor_metadata(x, ignore_storage_offset=True): def _tensor_metadata(x, ignore_storage_offset=True):
@ -1520,6 +1620,22 @@ class CUDAGraphTreeManager:
# now, we are in a recording state ! # now, we are in a recording state !
return self.record_function(new_inputs, function_id) return self.record_function(new_inputs, function_id)
def remove_all_cached_tensors(self):
"""
Remove all cached tensors in all nodes. Because cached tensors can hold gradients which in turn
might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown
to avoid a reference cycle.
"""
nodes = []
for roots in self.roots.values():
nodes.extend(roots)
while nodes:
node = nodes.pop()
for children in node.children.values():
nodes.extend(children)
node.remove_node_cached_tensors()
def record_function(self, new_inputs, function_id) -> List[Optional[Tensor]]: def record_function(self, new_inputs, function_id) -> List[Optional[Tensor]]:
torch.cuda.synchronize() torch.cuda.synchronize()
node = CUDAGraphNode( node = CUDAGraphNode(
@ -1708,6 +1824,10 @@ class CUDAGraphTreeManager:
# currently we deallocate on instead of allowing stale recordings # currently we deallocate on instead of allowing stale recordings
stale_storages = [] stale_storages = []
# remove cached tensors, otherwise they would prevent memory from being
# reclaimed in subsequent recordings
self.current_node.remove_path_cached_tensors()
live_storages_wrappers = list(self.current_node.path_live_weakrefs()) live_storages_wrappers = list(self.current_node.path_live_weakrefs())
live_storages_weak_refs = [t() for t in live_storages_wrappers] live_storages_weak_refs = [t() for t in live_storages_wrappers]

View File

@ -1089,6 +1089,11 @@ static void registerCudaPluggableAllocator(PyObject* module) {
return (storage_impl->data_ptr().get_deleter() == alloc->raw_deleter()); return (storage_impl->data_ptr().get_deleter() == alloc->raw_deleter());
}); });
m.def("_storage_Use_Count", [](size_t storage_impl_ptr) {
c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
return c10::raw::weak_intrusive_ptr::use_count(storage_impl);
});
m.def( m.def(
"_construct_CUDA_Tensor_From_Storage_And_Metadata", "_construct_CUDA_Tensor_From_Storage_And_Metadata",
[](py::dict& metadata, c10::Storage s) { [](py::dict& metadata, c10::Storage s) {