Fix persistent worker exits before pin_memory thread (#71579)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71579

Fixes #1551

As the comment in the code, register a function to terminate persistent workers.
By adding a reference of these workers in `atexit`, it would prevent Python interpreter kills these persistent worker processes before `pin_memorh_thread` exits.
And, if users explicitly kills DataLoader iterator, such function in `atexit` would be a no-op.

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D33896537

Pulled By: ejguan

fbshipit-source-id: 36b57eac7523d8aa180180c2b61fc693ea4638ae
(cherry picked from commit 05add2ae0f)
This commit is contained in:
Erjia Guan 2022-02-01 15:52:49 -08:00 committed by PyTorch MergeBot
parent 10cc66bc78
commit 67a275c293
2 changed files with 57 additions and 0 deletions

View File

@ -2349,6 +2349,42 @@ except RuntimeError as e:
# and can cache values safely
dataset.start = i
@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
@unittest.skipIf(IS_WINDOWS, "Needs fork")
def test_early_exit(self):
import subprocess
proc = subprocess.check_output([sys.executable, '-c', """\
import torch
from torch.utils.data import DataLoader, IterableDataset
class RandomDataset(IterableDataset):
def __init__(self, len, size):
super(RandomDataset).__init__()
self.len = len
self.size = size
def __iter__(self):
return self
def __next__(self):
if self.len <= 0:
raise StopIteration
self.len -= 1
return torch.randn(self.size)
if __name__ == '__main__':
dl = DataLoader(
RandomDataset(64, (28, 28)),
batch_size=16,
num_workers=2,
pin_memory=True,
persistent_workers=True,
multiprocessing_context="fork",
)
for _ in dl:
break
"""])
class NamedTupleDataset(Dataset):

View File

@ -946,6 +946,18 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
else:
self._data_queue = self._worker_result_queue
# In some rare cases, persistent workers (daemonic processes)
# would be terminated before `__del__` of iterator is invoked
# when main process exits
# It would cause failure when pin_memory_thread tries to read
# corrupted data from worker_result_queue
# atexit is used to shutdown thread and child processes in the
# right sequence before main process exits
if self._persistent_workers and self._pin_memory:
import atexit
for w in self._workers:
atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
# .pid can be None only before process is spawned (not the case, so ignore)
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
_utils.signal_handling._set_SIGCHLD_handler()
@ -1333,5 +1345,14 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
# we kill the worker.
w.terminate()
# staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
@staticmethod
def _clean_up_worker(w):
try:
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
finally:
if w.is_alive():
w.terminate()
def __del__(self):
self._shutdown_workers()