[fix] manual_seed{_all}: mem leak (#62534)

Summary:
Fixes: https://github.com/pytorch/pytorch/issues/55768
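
Before this change, every call to torch.cuda.manual_seed / manual_seed_all made
prior to CUDA initialization appended another (callback, formatted traceback)
pair to torch.cuda._queued_calls, so seeding in a loop grew memory without
bound. A minimal sketch of the leak (assumes CUDA has not been initialized yet;
the count is illustrative):

    import torch

    # Pre-fix: each iteration queues a fresh callback plus its formatted
    # stack trace in torch.cuda._queued_calls.
    # Post-fix: _LazySeedTracker keeps only the latest seed callback.
    for step in range(100_000):
        torch.cuda.manual_seed(step)

    print(len(torch.cuda._queued_calls))  # pre-fix: grows with the loop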

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62534

Reviewed By: nairbv

Differential Revision: D30103294

Pulled By: ezyang

fbshipit-source-id: d871ae869314dfd2d27544a51107ab752abfe452
Author: kshitij12345
Date: 2021-08-04 12:59:36 -07:00
Committed by: Facebook GitHub Bot
Commit: 6f0abba04c (parent 89f898ebb5)
2 changed files with 43 additions and 6 deletions

torch/cuda/__init__.py

@@ -32,6 +32,33 @@ _queued_calls = []  # don't invoke these until initialization occurs
 _is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False)
 _device_t = Union[_device, str, int, None]
 
+
+class _LazySeedTracker:
+    # Since seeding is memory-less, only track the latest seed.
+    # Note: `manual_seed_all` followed by `manual_seed` overwrites
+    # the seed on the current device. We track the order of the **latest**
+    # calls between these two APIs.
+    def __init__(self):
+        self.manual_seed_all_cb = None
+        self.manual_seed_cb = None
+        self.call_order = []
+
+    def queue_seed_all(self, cb, traceback):
+        self.manual_seed_all_cb = (cb, traceback)
+        # update seed_all to be latest
+        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
+
+    def queue_seed(self, cb, traceback):
+        self.manual_seed_cb = (cb, traceback)
+        # update seed to be latest
+        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
+
+    def get_calls(self) -> List:
+        return self.call_order
+
+
+_lazy_seed_tracker = _LazySeedTracker()
+
 # Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
 if hasattr(torch._C, '_CudaDeviceProperties'):
     _CudaDeviceProperties = torch._C._CudaDeviceProperties
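
To illustrate the latest-wins ordering the tracker encodes (a standalone
sketch using the class above; the print callbacks stand in for the real
seeding closures, and the "tb-*" strings for formatted tracebacks):

    tracker = _LazySeedTracker()
    tracker.queue_seed(lambda: print("manual_seed(2)"), "tb-seed")
    tracker.queue_seed_all(lambda: print("manual_seed_all(3)"), "tb-all")

    # get_calls() yields the older call first, so replay runs
    # manual_seed(2) and then manual_seed_all(3): the later
    # manual_seed_all wins on every device, matching eager semantics.
    for entry in tracker.get_calls():
        if entry:
            cb, _tb = entry
            cb()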
@@ -111,16 +138,21 @@ def is_initialized():
     return _initialized and not _is_in_bad_fork()
 
 
-def _lazy_call(callable):
+def _lazy_call(callable, **kwargs):
     if is_initialized():
         callable()
     else:
         # TODO(torch_deploy): this accesses linecache, which attempts to read the
         # file system to get traceback info. Patch linecache or do something
         # else here if this ends up being important.
-        # Don't store the actual traceback to avoid memory cycle
-        _queued_calls.append((callable, traceback.format_stack()))
+        global _lazy_seed_tracker
+        if kwargs.get("seed_all", False):
+            _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
+        elif kwargs.get("seed", False):
+            _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
+        else:
+            # Don't store the actual traceback to avoid memory cycle
+            _queued_calls.append((callable, traceback.format_stack()))
 
 
 _lazy_call(_check_capability)
 _lazy_call(_check_cubins)
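
For reference, a sketch of how the new keyword routing behaves before
initialization (`_lazy_call` is a private helper; calling it directly here
is for illustration only):

    cb = lambda: None  # stand-in for a real seeding closure

    torch.cuda._lazy_call(cb)                 # no keyword: appended to _queued_calls
    torch.cuda._lazy_call(cb, seed=True)      # routed to _lazy_seed_tracker.queue_seed
    torch.cuda._lazy_call(cb, seed=True)      # replaces the prior seed entry; no growth
    torch.cuda._lazy_call(cb, seed_all=True)  # routed to queue_seed_all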
@@ -174,6 +206,11 @@ def _lazy_init():
         # we need to just return without initializing in that case.
         # However, we must not let any *other* threads in!
         _tls.is_initializing = True
+
+        for calls in _lazy_seed_tracker.get_calls():
+            if calls:
+                _queued_calls.append(calls)
+
         try:
             for queued_call, orig_traceback in _queued_calls:
                 try:
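
Note the `if calls:` guard on the splice above: when only one of the two
seeding APIs was used before initialization, the other slot in `call_order`
is still `None`. A tiny sketch of that case (stand-in callback and traceback
string):

    tracker = _LazySeedTracker()
    tracker.queue_seed_all(lambda: None, "tb")
    print(tracker.get_calls())  # [None, (<callback>, 'tb')] -- the None is skipped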

torch/cuda/random.py

@@ -92,7 +92,7 @@ def manual_seed(seed: int) -> None:
         default_generator = torch.cuda.default_generators[idx]
         default_generator.manual_seed(seed)
 
-    _lazy_call(cb)
+    _lazy_call(cb, seed=True)
 
 
 def manual_seed_all(seed: int) -> None:
@@ -110,7 +110,7 @@ def manual_seed_all(seed: int) -> None:
         default_generator = torch.cuda.default_generators[i]
         default_generator.manual_seed(seed)
 
-    _lazy_call(cb)
+    _lazy_call(cb, seed_all=True)
 
 
 def seed() -> None:
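
A usage sketch of the ordering the tracker preserves (assumes a CUDA build;
torch.cuda.init() only forces the lazy initialization):

    import torch

    torch.cuda.manual_seed_all(7)  # queued lazily: seed every device
    torch.cuda.manual_seed(42)     # queued lazily: seed the current device

    torch.cuda.init()
    # Replay runs manual_seed_all(7) first, then manual_seed(42), so the
    # current device ends up seeded with 42 and the others with 7 -- the
    # same outcome as calling both functions after initialization.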