Cache output tensors on execution (#98944)

Caches output tensors for the common case when the output Tensor's storage is unaliased for all graph outputs in all paths. For these persisted tensors, we adjust liveness tracking by also checking that the output tensor does not have an additional Python reference.
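
A minimal sketch of that liveness rule (illustrative names only, not the PR's actual classes): the node keeps one strong reference to each cached output, so the output only counts as live to the user when some reference beyond the cache's own exists.

    import sys

    class CachedOutputSlot:
        """Sketch: liveness tracking for one cached output tensor."""

        def __init__(self, tensor):
            # the cache itself holds one strong reference to the output
            self.cached_tensor = tensor

        def has_external_reference(self):
            # sys.getrefcount adds one temporary reference for its argument,
            # so a tensor held only by this cache reports a count of 2;
            # anything above 2 means the caller still holds the output
            return sys.getrefcount(self.cached_tensor) > 2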

I limit cached output tensors to be unaliased. If a descendant node discovers it has an alias of a prior output, then the aliased output will no longer be persisted in the ancestor.
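
A rough sketch of that invalidation (illustrative helper, not the exact code in this PR): when a descendant records an output that aliases an ancestor's output, it clears the ancestor's unaliased flag and tells the corresponding weakref wrapper to stop expecting the cache's extra reference.

    def mark_prior_output_as_aliased(path_from_root, depth, output_index):
        # Sketch: the ancestor at `depth` may no longer cache this output,
        # and its storage weakref should stop expecting an extra reference.
        ancestor = list(path_from_root)[depth]
        ancestor.unaliased_in_all_paths[output_index] = False
        ancestor.outputs_weakrefs[output_index].remove_extra_reference()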

The large majority of tensors are unaliased, and preserving aliased output tensors would add significant additional complexity with marginal gains. For instance, when doing checkpointing and re-recording, we need to remove the persisted tensors; otherwise they would prevent memory from being reclaimed. If a single persisted tensor were present in multiple paths, that would create an inter-path dependence which adds complexity. Additionally, each further caching of the output would affect the reference count of the other caches, and that reference count would also need to be adjusted depending on whether a node was checkpointed.
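
As a small sketch of the checkpointing interaction (again with illustrative names): every node on the current path drops its cached outputs before a checkpoint or re-recording, since those strong references would otherwise keep the pool memory alive.

    def remove_path_cached_tensors(path_from_root):
        # Sketch: drop each node's cached outputs so that checkpointing or
        # re-recording can reclaim their memory.
        for node in path_from_root:
            node.cached_tensor_outputs.clear()
            for i, unaliased in enumerate(node.unaliased_in_all_paths):
                if unaliased:
                    # stop expecting the cache's extra reference in liveness checks
                    node.outputs_weakrefs[i].remove_extra_reference()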

A complete benchmark run is still needed, but for the models I tried this makes performance extremely close between the trees and non-trees implementations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98944
Approved by: https://github.com/jansel, https://github.com/ngimel
Elias Ellison 2023-04-17 21:44:00 +00:00 committed by PyTorch MergeBot
parent 93b64f0ad3
commit 472f46635e
5 changed files with 301 additions and 60 deletions


@ -100,10 +100,17 @@ if HAS_CUDA and not TEST_WITH_ASAN:
class CudaGraphTreeTests(TestCase):
def setUp(self):
super().setUp()
self.prev_enabled = config.triton.cudagraphs
self.tapes_enabled = config.triton.cudagraph_trees
config.triton.cudagraphs = True
config.triton.cudagraph_trees = True
self.graph_stack = contextlib.ExitStack()
self.graph_stack.enter_context(
config.patch(
{
"triton.cudagraphs": True,
"triton.cudagraph_trees": True,
"triton.fast_path_cudagraph_asserts": True, # too slow
"triton.slow_path_cudagraph_asserts": True,
}
)
)
self.device_idx = torch.rand([0], device="cuda").device.index
warnings.filterwarnings("ignore")
@ -111,10 +118,12 @@ if HAS_CUDA and not TEST_WITH_ASAN:
super().tearDown()
torch._dynamo.reset()
gc.collect()
config.triton.cudagraphs = self.prev_enabled
config.triton.cudagraph_trees = self.tapes_enabled
torch.cuda.empty_cache()
self.graph_stack.close()
self.assertIsNone(self.get_manager())
self.assertEqual(all_live_block_count(), 0)
self.assertEqual(len(get_all_cudagraph_segments()), 0)
warnings.resetwarnings()
def get_manager(self, device_index=None):
@ -463,7 +472,11 @@ if HAS_CUDA and not TEST_WITH_ASAN:
ptr_to_ref[out.untyped_storage().data_ptr()],
out.untyped_storage()._cdata,
)
del outs
del out
node = self.get_manager().current_node
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
self.assertFalse(self.get_manager().new_graph_id().id == 0)
def test_aliasing_static_ref(self):
@ -484,15 +497,63 @@ if HAS_CUDA and not TEST_WITH_ASAN:
x = torch.rand([10, 10], device="cuda", requires_grad=True)
param_c = cdata(m.weight)
for _ in range(3):
# print("Runnng foo")
out1, alias_1, alias_2 = foo(m, x)
self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)
# print("Runnng foo2")
out2 = foo2(out1)
out2.sum().backward()
self.assertEqual(cdata(out1), cdata(out2))
node = self.curr_node()
first_node = next(node._path_from_root)
self.assertFalse(first_node.unaliased_in_all_paths[0])
self.assertTrue(first_node.cached_tensor_outputs[0] is None)
def test_checkpointing_resets_persistent_refs(self):
@torch.compile(mode="reduce-overhead")
def foo(x):
return x @ x
def inp():
return torch.rand([20, 20], device="cuda", requires_grad=False)
for _ in range(3):
foo(inp())
self.assertEqual(self.num_checkpoints(), 0)
out = foo(inp())
out_id = id(out)
del out
self.assertEqual(id(foo(inp())), out_id)
@torch.compile(mode="reduce-overhead")
def foo2(x):
return x[0], x @ x
for i in range(2):
out = foo(inp())
from torch._dynamo.mutation_guard import GenerationTracker
GenerationTracker.generation -= 1
out_alias, out2 = foo2(out)
del out_alias
self.assertEqual(all_live_block_count(), 2)
del out
self.assertEqual(all_live_block_count(), 1)
del out2
self.assertEqual(all_live_block_count(), 0)
self.assertEqual(self.num_checkpoints(), i + 1)
new_out = foo(inp())
curr_node = self.curr_node()
self.assertFalse(curr_node.unaliased_in_all_paths[0])
self.assertFalse(out_id == id(new_out))
def test_aliased_static_parameter(self):
inp = torch.rand([20, 20], device="cuda")
@ -507,6 +568,29 @@ if HAS_CUDA and not TEST_WITH_ASAN:
out = foo_cg([inp])[0]
self.assertEqual(cdata(inp), cdata(out))
node = self.curr_node()
self.assertEqual(node.cached_tensor_outputs, [None])
self.assertEqual(node.unaliased_in_all_paths, [False])
def test_output_alias(self):
inp = torch.rand([20, 20], device="cuda")
def foo(args):
x = args[0]
args.clear()
out = x + x
return (x, x[0])
foo_cg = self.cudagraphify_impl(foo, [inp])
for _ in range(3):
out_1, out_2 = foo_cg([inp])
self.assertEqual(cdata(out_1), cdata(out_2))
del out_1, out_2
self.assertEqual(len(list(self.curr_node().path_live_weakrefs())), 0)
self.assertEqual(self.curr_node().cached_tensor_outputs, [None, None])
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_aliased_output_checkpoint(self):
def foo(args):
@ -543,6 +627,28 @@ if HAS_CUDA and not TEST_WITH_ASAN:
del x
self.assertEqual(all_live_block_count(), 0)
def test_persisted_output_liveness(self):
@torch.compile
def foo(x):
return x + x
for _ in range(3):
foo(torch.rand([2, 2], device="cuda"))
node = self.get_manager().current_node
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
out = foo(torch.rand([2, 2], device="cuda"))
self.assertTrue(out is node.cached_tensor_outputs[0])
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
out_ref = out[0:]
del out
self.assertEqual(len(list(node.path_live_weakrefs())), 1)
del out_ref
self.assertEqual(len(list(node.path_live_weakrefs())), 0)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_tensor_no_longer_in_pool(self):
def foo(args):
@ -771,11 +877,13 @@ if HAS_CUDA and not TEST_WITH_ASAN:
inp = torch.rand([4], device="cuda", requires_grad=True)
streams = set()
for _ in range(3):
streams_init = {seg["stream"] for seg in get_all_cudagraph_segments()}
for _ in range(4):
foo(inp).sum().backward()
streams = {seg["stream"] for seg in get_all_cudagraph_segments()}
streams = {
seg["stream"] for seg in get_all_cudagraph_segments()
} - streams_init
self.assertEqual(len(streams), 1)
self.assertFalse(self.get_manager().new_graph_id().id == 0)


@ -1529,6 +1529,12 @@ class _TorchCompileInductorWrapper:
return compile_fx(model_, inputs_, config_patches=self.config)
def reset(self):
from torch._inductor import config
if "triton.cudagraphs" in self.config or config.triton.cudagraphs:
if self.config.get("triton.cudagraphs", True):
from torch._inductor.cudagraph_trees import reset_cudagraph_trees
reset_cudagraph_trees()
def compile(model: Optional[Callable] = None, *,
fullgraph: builtins.bool = False,


@ -51,6 +51,8 @@ def reset():
orig_code_map.clear()
guard_failures.clear()
resume_execution.ContinueExecutionCache.cache.clear()
if hasattr(eval_frame.most_recent_backend, "reset"):
eval_frame.most_recent_backend.reset()
eval_frame.most_recent_backend = None
compilation_metrics.clear()
reset_frame_count()


@ -41,6 +41,7 @@ import dataclasses
import functools
import gc
import itertools
import sys
import threading
import warnings
import weakref
@ -245,6 +246,22 @@ torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_contain
torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks)
def reset_cudagraph_trees():
"Clear all cudagraph trees"
# see remove_all_cached_tensors below for why this is necessary
container_dict = get_obj(local, "tree_manager_containers")
locks_dict = get_obj(local, "tree_manager_locks")
for device, lock in locks_dict.items():
with lock:
container = container_dict.get(device)
if not container or not container.tree_manager:
continue
container.tree_manager.remove_all_cached_tensors()
container_dict.clear()
def get_obj(local, attr_name):
if hasattr(local, attr_name):
return getattr(local, attr_name)
@ -324,11 +341,15 @@ class StorageWeakRefWrapper:
Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked.
"""
__slots__ = ["ref", "_data_ptr"]
__slots__ = ["ref", "_data_ptr", "extra_ref_check"]
storage_ref: Optional[StorageWeakRef]
def __init__(self, inp: Union[Tensor, UntypedStorage]):
def __init__(
self,
inp: Union[Tensor, UntypedStorage],
extra_ref_check: Optional[Callable[[], bool]] = None,
):
if isinstance(inp, Tensor):
stor = inp.untyped_storage()
else:
@ -336,28 +357,41 @@ class StorageWeakRefWrapper:
stor = inp
self.ref = StorageWeakRef(stor)
self._data_ptr = stor.data_ptr()
self.extra_ref_check = extra_ref_check
@classmethod
def from_weakref_and_data_ptr(cls, cdata, data_ptr):
def from_weakref_and_data_ptr(cls, cdata, data_ptr, extra_ref_check=None):
instance = cls.__new__(cls)
instance._data_ptr = data_ptr
instance.ref = StorageWeakRef.from_weakref(cdata)
instance.extra_ref_check = extra_ref_check
return instance
def __call__(self) -> Optional[StorageWeakRefPointer]:
if self.ref is None:
return None
if self.ref.expired():
self.ref = None
if self.expired():
return None
return self.ref.cdata
def swap_weakref(self, cdata):
self.ref.__del__()
self.ref.cdata = cdata
def data_ptr(self) -> int:
"NB: returns the data ptr even if the storage has expired"
return self._data_ptr
def remove_extra_reference(self):
self.extra_ref_check = None
def expired(self):
if self.extra_ref_check is not None and not self.extra_ref_check():
return False
# if extra_ref_check is not None we expect an additional reference
stor_count = torch._C._storage_Use_Count(self.ref.cdata)
return (stor_count - (self.extra_ref_check is not None)) == 0
def __repr__(self):
if self.ref is None or self.ref.expired():
return f"StorageWeakRefWrapper to {self.data_ptr()}; dead"
@ -725,6 +759,18 @@ class CUDAGraphNode:
# - An alias of an output already created in the reconstructed outputs
self.output_storage_alias: OutputList[OutputAliasInfo] = []
# is the output Storage unaliased in subsequent outputs, in all subsequent paths?
# if it is, we cache the output tensor and adjust storage liveness tracking to also
# check that the output tensor does not have an additional python reference.
# If a descendant node discovers it has an alias of a prior output, then the output
# will no longer be cached in the ancestor.
# The large majority of tensors are unaliased, and preserving aliased output tensors would add
# significant additional complexity with marginal gains.
# The cached tensor outputs are added on the first execution, and cleared whenever we need
# to do subsequent recording.
self.unaliased_in_all_paths: OutputList[bool] = []
self.cached_tensor_outputs: OutputList[Optional[Tensor]] = []
# if an output aliases a static, persistent input then the Storage of the
# persistent output will be set here
self.output_persistent_storage: OutputList[Optional[UntypedStorage]] = []
@ -788,8 +834,6 @@ class CUDAGraphNode:
self.run_graph()
outputs = self.reconstruct_outputs()
self._add_replayed_outputs(outputs)
self.debug_check_invariants_after_invocation()
return outputs
@ -797,35 +841,44 @@ class CUDAGraphNode:
def reconstruct_outputs(self):
"Reconstruct output tensors according to their saved metadata and alias information"
# The cpp function is constructing a new Tensor according to the saved output metadata
# For each element in the corresponding storage list:
# - if a Storage is contained, that will be used
# - if None is contained, a new Storage will be constructed
# - if an int is contained, the storage from the output list at that int will be used
storages_info: List[
Union[UntypedStorage, None, int]
] = self.prepare_storages_for_construction()
outputs_new = []
# Cached tensors will not yet be set on the first execution
# They are also cleared in checkpointing, so if we checkpoint this node
# and then execute it again we will need to repopulate cached tensors
if not self.cached_tensor_outputs:
self._initialize_cached_tensors()
# We recreate the below logic in cpp to reduce overhead, since this is on the hot path
for storage_info, metadata in zip(storages_info, self.outputs_metadata):
outputs = []
for i, (storage_info, metadata) in enumerate(
zip(self.output_storage_alias, self.outputs_metadata)
):
if metadata is None:
outputs_new.append(None)
outputs.append(None)
continue
if storage_info is None:
s = self.create_storage(metadata)
elif isinstance(storage_info, UntypedStorage):
s = storage_info
else:
assert isinstance(storage_info, int)
s = outputs_new[storage_info].untyped_storage()
cached_t = self.cached_tensor_outputs[i]
if cached_t is not None:
# No need to update weakrefs, already correctly initialized
outputs.append(cached_t)
continue
outputs_new.append(
self._reconstruct_from_tensor_metadata(metadata, storage=s)
storage = self.prepare_alias_info_for_tensor_construction(
i, storage_info, metadata
)
return outputs_new
if isinstance(storage, UntypedStorage) or storage is None:
out = self._reconstruct_from_tensor_metadata(metadata, storage)
else:
assert isinstance(storage, int)
out = self._reconstruct_from_tensor_metadata(
metadata, outputs[storage].untyped_storage()
)
outputs.append(out)
if storage_info is not PersistentStaticStorage:
self.outputs_weakrefs[i].swap_weakref(out.untyped_storage()._weak_ref())
return outputs
def prepare_alias_info_for_tensor_construction(
self, out_index: int, out_alias_info: OutputAliasInfo, metadata: Dict[str, Any]
@ -839,7 +892,6 @@ class CUDAGraphNode:
if isinstance(out_alias_info, AliasesPriorGraphOutput):
depth, existing_output_index = out_alias_info.index
ref = self.path_weakrefs[depth][existing_output_index]
assert ref()
return torch.UntypedStorage._new_with_weak_ptr(ref())
assert isinstance(out_alias_info, AliasesNewOutput)
@ -928,6 +980,8 @@ class CUDAGraphNode:
# index from data pointer to index in outputs
output_new_storages_index: Dict[StorageDataPtr, int] = {}
self.unaliased_in_all_paths = [False for _ in range(len(outputs))]
for i, o in enumerate(outputs):
if o is None:
self.output_storage_alias.append(UnaliasedStorage)
@ -948,19 +1002,19 @@ class CUDAGraphNode:
path_ref = self._is_alias_of_live_recorded_tensor(o)
if path_ref is not None:
self._mark_prior_graph_output_as_aliased(path_ref)
self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
continue
if o.untyped_storage().data_ptr() in output_new_storages_index:
self.output_storage_alias.append(
AliasesNewOutput(
output_new_storages_index[o.untyped_storage().data_ptr()]
)
)
index = output_new_storages_index[o.untyped_storage().data_ptr()]
self.unaliased_in_all_paths[index] = False
self.output_storage_alias.append(AliasesNewOutput(index))
continue
output_new_storages_index[o.untyped_storage().data_ptr()] = i
self.output_storage_alias.append(UnaliasedStorage)
self.unaliased_in_all_paths[i] = True
if self.stack_traces is None:
self.stack_traces = [None for _ in range(len(outputs))]
@ -969,9 +1023,14 @@ class CUDAGraphNode:
outputs
), "Wrong number of stack traces passed in"
self._add_replayed_outputs(outputs)
self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs)
assert not self.outputs_weakrefs
for out, persisted_storage in zip(outputs, self.output_persistent_storage):
if out is None or persisted_storage is not None:
self.outputs_weakrefs.append(None)
else:
self.outputs_weakrefs.append(StorageWeakRefWrapper(out))
self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs)
self.checkpointed_caching_state = torch._C._cuda_getCheckpointState(
self.device, self.cuda_graphs_pool
)
@ -986,15 +1045,46 @@ class CUDAGraphNode:
if config.triton.slow_path_cudagraph_asserts:
check_memory_pool(self.cuda_graphs_pool, list(self.path_live_weakrefs()))
def _add_replayed_outputs(self, outputs):
self.outputs_weakrefs.clear()
def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex):
"Remove a graph output from the unaliased, cached tensors in an ancestor node"
depth, output_index = index
node = list(self._path_from_root)[depth]
node.unaliased_in_all_paths[output_index] = False
self.path_weakrefs[depth][output_index].remove_extra_reference()
for out, persistent_storage in zip(outputs, self.output_persistent_storage):
if out is None or persistent_storage is not None:
self.outputs_weakrefs.append(None)
def _initialize_cached_tensors(self):
# we should not be clearing output_weakrefs, and they should be set in the first
# record run
assert len(self.outputs_weakrefs) == len(self.outputs_metadata)
for i, (storage_info, metadata, make_cached) in enumerate(
zip(
self.output_storage_alias,
self.outputs_metadata,
self.unaliased_in_all_paths,
)
):
if not make_cached:
self.cached_tensor_outputs.append(None)
continue
self.outputs_weakrefs.append(StorageWeakRefWrapper(out))
assert storage_info is UnaliasedStorage
s = self.create_storage(metadata)
out = self._reconstruct_from_tensor_metadata(metadata, storage=s)
self_ref = weakref.ref(self)
# one reference in our array, and calling sys.getrefcount bumps the refcount by one
def check_refcount(i):
return self_ref().get_output_refcount(i) == 2
check = functools.partial(check_refcount, i=i)
self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check)
self.cached_tensor_outputs.append(out)
def get_output_refcount(self, index):
return sys.getrefcount(self.cached_tensor_outputs[index])
@property
def parent(self):
@ -1156,6 +1246,16 @@ class CUDAGraphNode:
if is_live(out):
yield out
def remove_node_cached_tensors(self):
self.cached_tensor_outputs.clear()
for i, unaliased in enumerate(self.unaliased_in_all_paths):
if unaliased:
self.outputs_weakrefs[i].remove_extra_reference()
def remove_path_cached_tensors(self):
for node in self._path_from_root:
node.remove_node_cached_tensors()
def path_live_weakrefs_and_stacktraces(
self,
) -> Generator[Tuple[StorageWeakRefWrapper, Optional[str]]]:
@ -1166,9 +1266,9 @@ class CUDAGraphNode:
yield out, self.path_stacktraces[i][j]
def clear_path_state(self):
"Clear the output lists of all nodes in the path and the storage cache"
for li in self.path_weakrefs:
li.clear()
"Clear the path state in this current executing node"
# this doesn't actually do anything right now, leaving it as a placeholder
pass
@staticmethod
def _tensor_metadata(x, ignore_storage_offset=True):
@ -1520,6 +1620,22 @@ class CUDAGraphTreeManager:
# now, we are in a recording state !
return self.record_function(new_inputs, function_id)
def remove_all_cached_tensors(self):
"""
Remove all cached tensors in all nodes. Because cached tensors can hold gradients which in turn
might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown
to avoid a reference cycle.
"""
nodes = []
for roots in self.roots.values():
nodes.extend(roots)
while nodes:
node = nodes.pop()
for children in node.children.values():
nodes.extend(children)
node.remove_node_cached_tensors()
def record_function(self, new_inputs, function_id) -> List[Optional[Tensor]]:
torch.cuda.synchronize()
node = CUDAGraphNode(
@ -1708,6 +1824,10 @@ class CUDAGraphTreeManager:
# currently we deallocate instead of allowing stale recordings
stale_storages = []
# remove cached tensors, otherwise they would prevent memory from being
# reclaimed in subsequent recordings
self.current_node.remove_path_cached_tensors()
live_storages_wrappers = list(self.current_node.path_live_weakrefs())
live_storages_weak_refs = [t() for t in live_storages_wrappers]


@ -1089,6 +1089,11 @@ static void registerCudaPluggableAllocator(PyObject* module) {
return (storage_impl->data_ptr().get_deleter() == alloc->raw_deleter());
});
m.def("_storage_Use_Count", [](size_t storage_impl_ptr) {
c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
return c10::raw::weak_intrusive_ptr::use_count(storage_impl);
});
m.def(
"_construct_CUDA_Tensor_From_Storage_And_Metadata",
[](py::dict& metadata, c10::Storage s) {