[fix] manual_seed{_all}: mem leak (#62534)
Summary:
Fixes: https://github.com/pytorch/pytorch/issues/55768

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62534

Reviewed By: nairbv

Differential Revision: D30103294

Pulled By: ezyang

fbshipit-source-id: d871ae869314dfd2d27544a51107ab752abfe452
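For context on the bug: before this change, every pre-initialization call to torch.cuda.manual_seed or torch.cuda.manual_seed_all appended a (callback, formatted_stack) pair to the module-level _queued_calls list, so repeated reseeding grew memory without bound until CUDA was initialized. A minimal reproduction sketch, not part of the commit, assuming a CUDA build where torch.cuda has not been initialized yet (it pokes at the private _queued_calls list purely for illustration):

# Reproduction sketch for issue 55768 (illustration only).
import torch

assert not torch.cuda.is_initialized()

before = len(torch.cuda._queued_calls)
for step in range(10_000):
    torch.cuda.manual_seed_all(step)   # deferred until _lazy_init()
grown = len(torch.cuda._queued_calls) - before

# Without this fix: grown == 10_000, one entry (holding a formatted stack
# trace) per call. With it: grown == 0, since only the latest seed_all
# callback is kept in _lazy_seed_tracker and replayed once at init.
print(grown)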
This commit is contained in:
parent 89f898ebb5
commit 6f0abba04c
torch/cuda/__init__.py

@@ -32,6 +32,33 @@ _queued_calls = []  # don't invoke these until initialization occurs
 _is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
 _device_t = Union[_device, str, int, None]
+
+
+class _LazySeedTracker:
+    # Since seeding is memory-less, only track the latest seed.
+    # Note: `manual_seed_all` followed by `manual_seed` overwrites
+    # the seed on current device. We track the order of **latest**
+    # calls between these two API.
+    def __init__(self):
+        self.manual_seed_all_cb = None
+        self.manual_seed_cb = None
+        self.call_order = []
+
+    def queue_seed_all(self, cb, traceback):
+        self.manual_seed_all_cb = (cb, traceback)
+        # update seed_all to be latest
+        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
+
+    def queue_seed(self, cb, traceback):
+        self.manual_seed_cb = (cb, traceback)
+        # update seed to be latest
+        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
+
+    def get_calls(self) -> List:
+        return self.call_order
+
+
+_lazy_seed_tracker = _LazySeedTracker()
 
 # Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
 if hasattr(torch._C, '_CudaDeviceProperties'):
     _CudaDeviceProperties = torch._C._CudaDeviceProperties
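As a sanity check of the ordering contract documented in the class comment, a small sketch that drives the tracker directly. _LazySeedTracker is a private detail, importable from torch.cuda only on builds that include this commit; the lambdas stand in for the real seed closures:

# Ordering sketch for _LazySeedTracker (private API, illustration only).
from torch.cuda import _LazySeedTracker

tracker = _LazySeedTracker()
tracker.queue_seed_all(lambda: print("seed_all"), "<stack>")
tracker.queue_seed(lambda: print("seed current device"), "<stack>")

# The most recent call always lands at the end of call_order, so a replay
# applies seed_all first and lets the per-device seed win afterwards.
for entry in tracker.get_calls():
    if entry:                 # a slot stays None until its API was called
        cb, _tb = entry
        cb()                  # prints: seed_all, then seed current device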
@@ -111,16 +138,21 @@ def is_initialized():
     return _initialized and not _is_in_bad_fork()
 
 
-def _lazy_call(callable):
+def _lazy_call(callable, **kwargs):
     if is_initialized():
         callable()
     else:
         # TODO(torch_deploy): this accesses linecache, which attempts to read the
         # file system to get traceback info. Patch linecache or do something
         # else here if this ends up being important.
-
-        # Don't store the actual traceback to avoid memory cycle
-        _queued_calls.append((callable, traceback.format_stack()))
+        global _lazy_seed_tracker
+        if kwargs.get("seed_all", False):
+            _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
+        elif kwargs.get("seed", False):
+            _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
+        else:
+            # Don't store the actual traceback to avoid memory cycle
+            _queued_calls.append((callable, traceback.format_stack()))
 
 _lazy_call(_check_capability)
 _lazy_call(_check_cubins)
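Note the behavior split: the tracker only matters in the not-yet-initialized branch, and once CUDA is initialized _lazy_call still invokes the callable immediately. A quick sketch of the initialized path, assuming a CUDA-capable machine:

# After initialization, seeding is immediate -- no tracker, no queue.
import torch

torch.cuda.init()            # forces _lazy_init() to run
torch.cuda.manual_seed(42)   # cb() executes right away on current device
assert torch.cuda.initial_seed() == 42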
@@ -174,6 +206,11 @@ def _lazy_init():
     # we need to just return without initializing in that case.
     # However, we must not let any *other* threads in!
     _tls.is_initializing = True
+
+    for calls in _lazy_seed_tracker.get_calls():
+        if calls:
+            _queued_calls.append(calls)
+
     try:
         for queued_call, orig_traceback in _queued_calls:
             try:
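The replay above preserves the recorded order, so the usual manual_seed_all-then-manual_seed interplay survives lazy initialization. A sketch of the end-to-end behavior, assuming at least one CUDA device and a fresh, uninitialized context:

# End-to-end replay sketch: the latest of each seed call survives lazy init.
import torch

assert not torch.cuda.is_initialized()
torch.cuda.manual_seed_all(7)   # tracker slot: seed_all
torch.cuda.manual_seed(9)       # tracker slot: seed (recorded as latest)

torch.cuda.init()               # tracker entries are appended to
                                # _queued_calls in order, then executed
assert torch.cuda.initial_seed() == 9   # per-device seed applied last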
torch/cuda/random.py
@@ -92,7 +92,7 @@ def manual_seed(seed: int) -> None:
         default_generator = torch.cuda.default_generators[idx]
         default_generator.manual_seed(seed)
 
-    _lazy_call(cb)
+    _lazy_call(cb, seed=True)
 
 
 def manual_seed_all(seed: int) -> None:
@@ -110,7 +110,7 @@ def manual_seed_all(seed: int) -> None:
         default_generator = torch.cuda.default_generators[i]
         default_generator.manual_seed(seed)
 
-    _lazy_call(cb)
+    _lazy_call(cb, seed_all=True)
 
 
 def seed() -> None:
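From the caller's side nothing changes: manual_seed and manual_seed_all keep their signatures, and the seed=/seed_all= hints are purely internal. Typical usage remains as below (sketch assumes a CUDA build):

# The public seeding API is untouched by this commit.
import torch

torch.manual_seed(0)            # seeds CPU and, via hooks, all CUDA generators
torch.cuda.manual_seed_all(0)   # explicit; deferred if CUDA not yet initialized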