Cache output tensors on execution (#98944)
Caches output tensors for the common case where the output Tensor storage is unaliased for all graph outputs in all paths. For these persisted tensors we adjust the liveness tracking to also check that the output tensor does not hold an additional Python reference.

I limit cached output tensors to unaliased storages. If a descendant node discovers it has an alias of a prior output, then the aliased output will no longer be persisted in the ancestor. The large majority of tensors are unaliased, and preserving aliased output tensors would add significant additional complexity with marginal gains. For instance, when checkpointing and re-recording we need to remove the persisted tensors, otherwise they would prevent memory from being reclaimed. If a single persisted tensor were present in multiple paths, that would create an inter-path dependence, which adds complexity. Additionally, each further caching of the output would affect the reference count of the other caches, and that reference count would also need to be adjusted depending on whether a node was checkpointed.

I still need to do a complete run, but for the models I tried this makes performance extremely close between the trees and non-trees implementations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98944
Approved by: https://github.com/jansel, https://github.com/ngimel
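A rough, self-contained sketch of the liveness adjustment described above (illustrative only, using a hypothetical CachedOutputSlot rather than the PR's CUDAGraphNode/StorageWeakRefWrapper machinery): an output tensor held only by the node's cache can be treated as dead, while one the caller still references stays live. The real code performs the analogous check through an extra_ref_check callback and torch._C._storage_Use_Count, as shown in the diff below.

    import sys

    import torch


    class CachedOutputSlot:
        """Hypothetical stand-in for a graph node's cached output entry."""

        def __init__(self, tensor):
            # The one strong reference the cache keeps so replays can return
            # the same Python tensor object every time.
            self.tensor = tensor

        def has_external_reference(self):
            # sys.getrefcount counts its own argument as one extra reference,
            # so a tensor referenced only by this cache reports exactly 2.
            return sys.getrefcount(self.tensor) > 2


    slot = CachedOutputSlot(torch.ones(4))
    user_handle = slot.tensor          # caller still holds the replayed output
    assert slot.has_external_reference()

    del user_handle                    # only the cache's reference remains
    assert not slot.has_external_reference()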
This commit is contained in:
parent
93b64f0ad3
commit
472f46635e
@@ -100,10 +100,17 @@ if HAS_CUDA and not TEST_WITH_ASAN:
     class CudaGraphTreeTests(TestCase):
         def setUp(self):
             super().setUp()
-            self.prev_enabled = config.triton.cudagraphs
-            self.tapes_enabled = config.triton.cudagraph_trees
-            config.triton.cudagraphs = True
-            config.triton.cudagraph_trees = True
+            self.graph_stack = contextlib.ExitStack()
+            self.graph_stack.enter_context(
+                config.patch(
+                    {
+                        "triton.cudagraphs": True,
+                        "triton.cudagraph_trees": True,
+                        "triton.fast_path_cudagraph_asserts": True,  # too slow
+                        "triton.slow_path_cudagraph_asserts": True,
+                    }
+                )
+            )
             self.device_idx = torch.rand([0], device="cuda").device.index
             warnings.filterwarnings("ignore")

@@ -111,10 +118,12 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             super().tearDown()
             torch._dynamo.reset()
             gc.collect()
-            config.triton.cudagraphs = self.prev_enabled
-            config.triton.cudagraph_trees = self.tapes_enabled
+            torch.cuda.empty_cache()
+            self.graph_stack.close()

             self.assertIsNone(self.get_manager())
             self.assertEqual(all_live_block_count(), 0)
+            self.assertEqual(len(get_all_cudagraph_segments()), 0)
             warnings.resetwarnings()

         def get_manager(self, device_index=None):

@@ -463,7 +472,11 @@ if HAS_CUDA and not TEST_WITH_ASAN:
                     ptr_to_ref[out.untyped_storage().data_ptr()],
                     out.untyped_storage()._cdata,
                 )
+            del outs
+            del out

+            node = self.get_manager().current_node
+            self.assertEqual(len(list(node.path_live_weakrefs())), 0)
             self.assertFalse(self.get_manager().new_graph_id().id == 0)

         def test_aliasing_static_ref(self):

@@ -484,15 +497,63 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             x = torch.rand([10, 10], device="cuda", requires_grad=True)
             param_c = cdata(m.weight)
             for _ in range(3):
-                # print("Runnng foo")
                 out1, alias_1, alias_2 = foo(m, x)
                 self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)

-                # print("Runnng foo2")
                 out2 = foo2(out1)
                 out2.sum().backward()
                 self.assertEqual(cdata(out1), cdata(out2))

+            node = self.curr_node()
+            first_node = next(node._path_from_root)
+            self.assertFalse(first_node.unaliased_in_all_paths[0])
+            self.assertTrue(first_node.cached_tensor_outputs[0] is None)
+
+        def test_checkpointing_resets_persistent_refs(self):
+            @torch.compile(mode="reduce-overhead")
+            def foo(x):
+                return x @ x
+
+            def inp():
+                return torch.rand([20, 20], device="cuda", requires_grad=False)
+
+            for _ in range(3):
+                foo(inp())
+
+            self.assertEqual(self.num_checkpoints(), 0)
+
+            out = foo(inp())
+            out_id = id(out)
+            del out
+            self.assertEqual(id(foo(inp())), out_id)
+
+            @torch.compile(mode="reduce-overhead")
+            def foo2(x):
+                return x[0], x @ x
+
+            for i in range(2):
+                out = foo(inp())
+
+                from torch._dynamo.mutation_guard import GenerationTracker
+
+                GenerationTracker.generation -= 1
+
+                out_alias, out2 = foo2(out)
+                del out_alias
+
+                self.assertEqual(all_live_block_count(), 2)
+                del out
+                self.assertEqual(all_live_block_count(), 1)
+                del out2
+                self.assertEqual(all_live_block_count(), 0)
+
+                self.assertEqual(self.num_checkpoints(), i + 1)
+
+                new_out = foo(inp())
+                curr_node = self.curr_node()
+                self.assertFalse(curr_node.unaliased_in_all_paths[0])
+                self.assertFalse(out_id == id(new_out))
+
         def test_aliased_static_parameter(self):
             inp = torch.rand([20, 20], device="cuda")

@@ -507,6 +568,29 @@ if HAS_CUDA and not TEST_WITH_ASAN:
                 out = foo_cg([inp])[0]
                 self.assertEqual(cdata(inp), cdata(out))

+            node = self.curr_node()
+            self.assertEqual(node.cached_tensor_outputs, [None])
+            self.assertEqual(node.unaliased_in_all_paths, [False])
+
+        def test_output_alias(self):
+            inp = torch.rand([20, 20], device="cuda")
+
+            def foo(args):
+                x = args[0]
+                args.clear()
+                out = x + x
+                return (x, x[0])
+
+            foo_cg = self.cudagraphify_impl(foo, [inp])
+
+            for _ in range(3):
+                out_1, out_2 = foo_cg([inp])
+                self.assertEqual(cdata(out_1), cdata(out_2))
+                del out_1, out_2
+                self.assertEqual(len(list(self.curr_node().path_live_weakrefs())), 0)
+
+            self.assertEqual(self.curr_node().cached_tensor_outputs, [None, None])
+
         @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
         def test_aliased_output_checkpoint(self):
             def foo(args):

@@ -543,6 +627,28 @@ if HAS_CUDA and not TEST_WITH_ASAN:
             del x
             self.assertEqual(all_live_block_count(), 0)

+        def test_peristed_output_livenes(self):
+            @torch.compile
+            def foo(x):
+                return x + x
+
+            for _ in range(3):
+                foo(torch.rand([2, 2], device="cuda"))
+
+            node = self.get_manager().current_node
+            self.assertEqual(len(list(node.path_live_weakrefs())), 0)
+
+            out = foo(torch.rand([2, 2], device="cuda"))
+            self.assertTrue(out is node.cached_tensor_outputs[0])
+            self.assertEqual(len(list(node.path_live_weakrefs())), 1)
+
+            out_ref = out[0:]
+            del out
+            self.assertEqual(len(list(node.path_live_weakrefs())), 1)
+
+            del out_ref
+            self.assertEqual(len(list(node.path_live_weakrefs())), 0)
+
         @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
         def test_tensor_no_longer_in_pool(self):
             def foo(args):

@@ -771,11 +877,13 @@ if HAS_CUDA and not TEST_WITH_ASAN:

             inp = torch.rand([4], device="cuda", requires_grad=True)
             streams = set()
+            streams_init = {seg["stream"] for seg in get_all_cudagraph_segments()}
-            for _ in range(3):
+            for _ in range(4):
                 foo(inp).sum().backward()

-            streams = {seg["stream"] for seg in get_all_cudagraph_segments()}
+            streams = {
+                seg["stream"] for seg in get_all_cudagraph_segments()
+            } - streams_init
             self.assertEqual(len(streams), 1)
             self.assertFalse(self.get_manager().new_graph_id().id == 0)

@@ -1529,6 +1529,12 @@ class _TorchCompileInductorWrapper:

         return compile_fx(model_, inputs_, config_patches=self.config)

+    def reset(self):
+        from torch._inductor import config
+        if "triton.cudagraphs" in self.config or config.triton.cudagraphs:
+            if self.config.get("triton.cudagraphs", True):
+                from torch._inductor.cudagraph_trees import reset_cudagraph_trees
+                reset_cudagraph_trees()

 def compile(model: Optional[Callable] = None, *,
             fullgraph: builtins.bool = False,

@@ -51,6 +51,8 @@ def reset():
     orig_code_map.clear()
     guard_failures.clear()
     resume_execution.ContinueExecutionCache.cache.clear()
+    if hasattr(eval_frame.most_recent_backend, "reset"):
+        eval_frame.most_recent_backend.reset()
     eval_frame.most_recent_backend = None
     compilation_metrics.clear()
     reset_frame_count()

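A usage sketch of the reset path wired up in the two hunks above (assuming a CUDA device and a build where cudagraph trees are enabled; the scenario is illustrative, the function names are from this diff): torch._dynamo.reset() now also calls reset() on the most recent inductor backend, which clears cached cudagraph output tensors so pool memory can be reclaimed.

    import torch

    @torch.compile(mode="reduce-overhead")  # inductor backend with cudagraphs enabled
    def f(x):
        return x @ x

    f(torch.rand(8, 8, device="cuda"))

    # Resetting dynamo now reaches the backend's reset(), which in turn calls
    # reset_cudagraph_trees() and drops the cached output tensors.
    torch._dynamo.reset()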
@@ -41,6 +41,7 @@ import dataclasses
 import functools
 import gc
 import itertools
+import sys
 import threading
 import warnings
 import weakref

@@ -245,6 +246,22 @@ torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_contain
 torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks)


+def reset_cudagraph_trees():
+    "Clear all cudagraph trees"
+    # see remove_all_cached_tensors below for why this is necessary
+    container_dict = get_obj(local, "tree_manager_containers")
+    locks_dict = get_obj(local, "tree_manager_locks")
+    for device, lock in locks_dict.items():
+        with lock:
+            container = container_dict.get(device)
+            if not container or not container.tree_manager:
+                continue
+
+            container.tree_manager.remove_all_cached_tensors()
+
+    container_dict.clear()
+
+
 def get_obj(local, attr_name):
     if hasattr(local, attr_name):
         return getattr(local, attr_name)

@@ -324,11 +341,15 @@ class StorageWeakRefWrapper:
     Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked.
     """

-    __slots__ = ["ref", "_data_ptr"]
+    __slots__ = ["ref", "_data_ptr", "extra_ref_check"]

     storage_ref: Optional[StorageWeakRef]

-    def __init__(self, inp: Union[Tensor, UntypedStorage]):
+    def __init__(
+        self,
+        inp: Union[Tensor, UntypedStorage],
+        extra_ref_check: Optional[Callable[[], None]] = None,
+    ):
         if isinstance(inp, Tensor):
             stor = inp.untyped_storage()
         else:

@@ -336,28 +357,41 @@ class StorageWeakRefWrapper:
             stor = inp
         self.ref = StorageWeakRef(stor)
         self._data_ptr = stor.data_ptr()
+        self.extra_ref_check = extra_ref_check

     @classmethod
-    def from_weakref_and_data_ptr(cls, cdata, data_ptr):
+    def from_weakref_and_data_ptr(cls, cdata, data_ptr, extra_ref_check=None):
         instance = cls.__new__(cls)
         instance._data_ptr = data_ptr
         instance.ref = StorageWeakRef.from_weakref(cdata)
+        instance.extra_ref_check = extra_ref_check
         return instance

     def __call__(self) -> Optional[StorageWeakRefPointer]:
-        if self.ref is None:
-            return None
-
-        if self.ref.expired():
-            self.ref = None
+        if self.expired():
             return None

         return self.ref.cdata

+    def swap_weakref(self, cdata):
+        self.ref.__del__()
+        self.ref.cdata = cdata
+
     def data_ptr(self) -> int:
         "NB: returns the data ptr even if the storage has expired"
         return self._data_ptr

+    def remove_extra_reference(self):
+        self.extra_ref_check = None
+
+    def expired(self):
+        if self.extra_ref_check is not None and not self.extra_ref_check():
+            return False
+
+        # if extra_ref_check is not None we expect an additional reference
+        stor_count = torch._C._storage_Use_Count(self.ref.cdata)
+        return (stor_count - (self.extra_ref_check is not None)) == 0
+
     def __repr__(self):
         if self.ref is None or self.ref.expired():
             return f"StorageWeakRefWrapper to {self.data_ptr()}; dead"

@@ -725,6 +759,18 @@ class CUDAGraphNode:
         # - An alias of an output already created in the reconstructed outputs
         self.output_storage_alias: OutputList[OutputAliasInfo] = []

+        # is the output Storage unaliased in subsequent outputs, of all subsequent paths
+        # if it is, we cached the output tensor and adjust storage liveness tracking to also
+        # check if the output tensor does not have an additional python reference.
+        # If a descendent node discovers it has an alias of a prior output, then the output
+        # will no longer be cached in the ancestor.
+        # The large majority of tensors are unaliased, and preserving aliased output tensors would add
+        # significant additional complexity with marginal gains
+        # The cached tensor outputs are added on the first execution, and cleared whenever we need
+        # to do subsequent recording
+        self.unaliased_in_all_paths: OutputList[bool] = []
+        self.cached_tensor_outputs: OutputList[Optional[Tensor]] = []
+
         # if an output aliases a static, persistent input then the Storage of the
         # persistent output will be set here
         self.output_persistent_storage: OutputList[Optional[UntypedStorage]] = []

@@ -788,8 +834,6 @@ class CUDAGraphNode:
         self.run_graph()

         outputs = self.reconstruct_outputs()
-
-        self._add_replayed_outputs(outputs)
         self.debug_check_invariants_after_invocation()

         return outputs

@@ -797,35 +841,44 @@ class CUDAGraphNode:
     def reconstruct_outputs(self):
         "Reconstruct output tensors according to their saved metadata and alias information"

-        # The cpp function is constructing a new Tensor according to the saved output metadata
-        # For each element in the corresponding storage list:
-        # - if a Storage is contained, that will be used
-        # - if None is contained, a new Storage will be constructed
-        # - if an int is contained, the storage from the output list at that int will be used
-        storages_info: List[
-            Union[UntypedStorage, None, int]
-        ] = self.prepare_storages_for_construction()
-        outputs_new = []
-
-        # # We recreate the below logic in cpp to reduce overhead, since this is on the hot path
-        for storage_info, metadata in zip(storages_info, self.outputs_metadata):
+        # Cached tensors will not yet be set on the first execution
+        # They are also cleared in checkpointing, so if we checkpoint this node
+        # and then execute it again we will need to repopulate cached tensors
+        if not self.cached_tensor_outputs:
+            self._initialize_cached_tensors()
+
+        outputs = []
+        for i, (storage_info, metadata) in enumerate(
+            zip(self.output_storage_alias, self.outputs_metadata)
+        ):
             if metadata is None:
-                outputs_new.append(None)
+                outputs.append(None)
                 continue

-            if storage_info is None:
-                s = self.create_storage(metadata)
-            elif isinstance(storage_info, UntypedStorage):
-                s = storage_info
-            else:
-                assert isinstance(storage_info, int)
-                s = outputs_new[storage_info].untyped_storage()
-
-            outputs_new.append(
-                self._reconstruct_from_tensor_metadata(metadata, storage=s)
-            )
-
-        return outputs_new
+            cached_t = self.cached_tensor_outputs[i]
+            if cached_t is not None:
+                # No need to update weakrefs, already correctly initialized
+                outputs.append(cached_t)
+                continue
+
+            storage = self.prepare_alias_info_for_tensor_construction(
+                i, storage_info, metadata
+            )
+
+            if isinstance(storage, UntypedStorage) or storage is None:
+                out = self._reconstruct_from_tensor_metadata(metadata, storage)
+            else:
+                assert isinstance(storage, int)
+                out = self._reconstruct_from_tensor_metadata(
+                    metadata, outputs[storage].untyped_storage()
+                )
+
+            outputs.append(out)
+            if storage_info is not PersistentStaticStorage:
+                self.outputs_weakrefs[i].swap_weakref(out.untyped_storage()._weak_ref())
+
+        return outputs

     def prepare_alias_info_for_tensor_construction(
         self, out_index: int, out_alias_info: OutputAliasInfo, metadata: Dict[str, Any]

@@ -839,7 +892,6 @@ class CUDAGraphNode:
         if isinstance(out_alias_info, AliasesPriorGraphOutput):
             depth, existing_output_index = out_alias_info.index
             ref = self.path_weakrefs[depth][existing_output_index]
-            assert ref()
             return torch.UntypedStorage._new_with_weak_ptr(ref())

         assert isinstance(out_alias_info, AliasesNewOutput)

@@ -928,6 +980,8 @@ class CUDAGraphNode:
         # index from data pointer to index in outputs
         output_new_storages_index: Dict[StorageDataPtr, int] = {}

+        self.unaliased_in_all_paths = [False for _ in range(len(outputs))]
+
         for i, o in enumerate(outputs):
             if o is None:
                 self.output_storage_alias.append(UnaliasedStorage)

@@ -948,19 +1002,19 @@ class CUDAGraphNode:

            path_ref = self._is_alias_of_live_recorded_tensor(o)
            if path_ref is not None:
+                self._mark_prior_graph_output_as_aliased(path_ref)
                self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
                continue

            if o.untyped_storage().data_ptr() in output_new_storages_index:
-                self.output_storage_alias.append(
-                    AliasesNewOutput(
-                        output_new_storages_index[o.untyped_storage().data_ptr()]
-                    )
-                )
+                index = output_new_storages_index[o.untyped_storage().data_ptr()]
+                self.unaliased_in_all_paths[index] = False
+                self.output_storage_alias.append(AliasesNewOutput(index))
                continue

            output_new_storages_index[o.untyped_storage().data_ptr()] = i
            self.output_storage_alias.append(UnaliasedStorage)
+            self.unaliased_in_all_paths[i] = True

        if self.stack_traces is None:
            self.stack_traces = [None for _ in range(len(outputs))]

@@ -969,9 +1023,14 @@ class CUDAGraphNode:
             outputs
         ), "Wrong number of stack traces passed in"

-        self._add_replayed_outputs(outputs)
-        self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs)
+        assert not self.outputs_weakrefs
+        for out, persisted_storage in zip(outputs, self.output_persistent_storage):
+            if out is None or persisted_storage is not None:
+                self.outputs_weakrefs.append(None)
+            else:
+                self.outputs_weakrefs.append(StorageWeakRefWrapper(out))
+
+        self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs)
         self.checkpointed_caching_state = torch._C._cuda_getCheckpointState(
             self.device, self.cuda_graphs_pool
         )

@@ -986,15 +1045,46 @@ class CUDAGraphNode:
         if config.triton.slow_path_cudagraph_asserts:
             check_memory_pool(self.cuda_graphs_pool, list(self.path_live_weakrefs()))

-    def _add_replayed_outputs(self, outputs):
-        self.outputs_weakrefs.clear()
-        for out, persistent_storage in zip(outputs, self.output_persistent_storage):
-            if out is None or persistent_storage is not None:
-                self.outputs_weakrefs.append(None)
-                continue
-
-            self.outputs_weakrefs.append(StorageWeakRefWrapper(out))
+    def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex):
+        "Remove a graph output from the unaliased, cached tensors in an ancestor node"
+        depth, output_index = index
+        node = list(self._path_from_root)[depth]
+        node.unaliased_in_all_paths[output_index] = False
+        self.path_weakrefs[depth][output_index].remove_extra_reference()
+
+    def _initialize_cached_tensors(self):
+        # we should not be clearing output_weakrefs, and they should be set in the first
+        # record run
+        assert len(self.outputs_weakrefs) == len(self.outputs_metadata)
+
+        for i, (storage_info, metadata, make_cached) in enumerate(
+            zip(
+                self.output_storage_alias,
+                self.outputs_metadata,
+                self.unaliased_in_all_paths,
+            )
+        ):
+            if not make_cached:
+                self.cached_tensor_outputs.append(None)
+                continue
+
+            assert storage_info is UnaliasedStorage
+            s = self.create_storage(metadata)
+            out = self._reconstruct_from_tensor_metadata(metadata, storage=s)
+
+            self_ref = weakref.ref(self)
+
+            # one reference in our array, and calling sys.getrefcount bumps the refcount by one
+            def check_refcount(i):
+                return self_ref().get_output_refcount(i) == 2
+
+            check = functools.partial(check_refcount, i=i)
+
+            self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check)
+            self.cached_tensor_outputs.append(out)
+
+    def get_output_refcount(self, index):
+        return sys.getrefcount(self.cached_tensor_outputs[index])

     @property
     def parent(self):

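One detail in the hunk above worth spelling out: check is built with functools.partial(check_refcount, i=i) rather than by letting the nested function read the loop variable, because a closure over i would see only the final loop value. A minimal sketch of the difference (illustrative only, not code from this commit):

    import functools

    # Closures capture the variable itself, so every callback sees the last value of i.
    late_bound = [lambda: i for i in range(3)]
    assert [f() for f in late_bound] == [2, 2, 2]

    # functools.partial binds the current value of i at creation time, which is what
    # check = functools.partial(check_refcount, i=i) relies on above.
    bound_now = [functools.partial(lambda i: i, i) for i in range(3)]
    assert [f() for f in bound_now] == [0, 1, 2]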
@@ -1156,6 +1246,16 @@ class CUDAGraphNode:
             if is_live(out):
                 yield out

+    def remove_node_cached_tensors(self):
+        self.cached_tensor_outputs.clear()
+        for i, unaliased in enumerate(self.unaliased_in_all_paths):
+            if unaliased:
+                self.outputs_weakrefs[i].remove_extra_reference()
+
+    def remove_path_cached_tensors(self):
+        for node in self._path_from_root:
+            node.remove_node_cached_tensors()
+
     def path_live_weakrefs_and_stacktraces(
         self,
     ) -> Generator[Tuple[StorageWeakRefWrapper, Optional[str]]]:

@@ -1166,9 +1266,9 @@ class CUDAGraphNode:
                 yield out, self.path_stacktraces[i][j]

     def clear_path_state(self):
-        "Clear the output lists of all nodes in the path and the storage cache"
-        for li in self.path_weakrefs:
-            li.clear()
+        "Clear the path state in this current executing node"
+        # this doesnt actually do anything right now, leaving it as placeholder
+        pass

     @staticmethod
     def _tensor_metadata(x, ignore_storage_offset=True):

@@ -1520,6 +1620,22 @@ class CUDAGraphTreeManager:
         # now, we are in a recording state !
         return self.record_function(new_inputs, function_id)

+    def remove_all_cached_tensors(self):
+        """
+        Remove all cached tensors in all nodes. Because cached tensors can hold gradients which in turn
+        might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown
+        to avoid a reference cycle.
+        """
+        nodes = []
+        for roots in self.roots.values():
+            nodes.extend(roots)
+
+        while nodes:
+            node = nodes.pop()
+            for children in node.children.values():
+                nodes.extend(children)
+            node.remove_node_cached_tensors()
+
     def record_function(self, new_inputs, function_id) -> List[Optional[Tensor]]:
         torch.cuda.synchronize()
         node = CUDAGraphNode(

@@ -1708,6 +1824,10 @@ class CUDAGraphTreeManager:

         # currently we deallocate on instead of allowing stale recordings
         stale_storages = []

+        # remove cached tensors, otherwise they would prevent memory from being
+        # reclaimed in subsequent recordings
+        self.current_node.remove_path_cached_tensors()
+
         live_storages_wrappers = list(self.current_node.path_live_weakrefs())

         live_storages_weak_refs = [t() for t in live_storages_wrappers]

@@ -1089,6 +1089,11 @@ static void registerCudaPluggableAllocator(PyObject* module) {
         return (storage_impl->data_ptr().get_deleter() == alloc->raw_deleter());
       });

+  m.def("_storage_Use_Count", [](size_t storage_impl_ptr) {
+    c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr;
+    return c10::raw::weak_intrusive_ptr::use_count(storage_impl);
+  });
+
   m.def(
       "_construct_CUDA_Tensor_From_Storage_And_Metadata",
       [](py::dict& metadata, c10::Storage s) {
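The binding above backs the Python-side expired() check in StorageWeakRefWrapper. A small usage sketch (assuming a CUDA build of PyTorch that ships this private helper; the starting count depends on how many references already point at the storage):

    import torch

    t = torch.ones(4)
    ptr = t.untyped_storage()._cdata          # raw StorageImpl pointer, as used in cudagraph_trees
    count = torch._C._storage_Use_Count(ptr)

    view = t[0:2]                             # a view shares the storage, adding one reference
    assert torch._C._storage_Use_Count(ptr) == count + 1

    del view
    assert torch._C._storage_Use_Count(ptr) == count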