mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix persistent worker exits before pin_memory thread (#71579)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71579
Fixes #1551
As described in the comment in the code, register a function that terminates the persistent workers.
By registering a reference to these workers with `atexit`, we prevent the Python interpreter from killing these persistent worker processes before the `pin_memory_thread` exits.
And, if users explicitly kill the DataLoader iterator, the function registered in `atexit` becomes a no-op.
Test Plan: Imported from OSS
Reviewed By: VitalyFedyunin
Differential Revision: D33896537
Pulled By: ejguan
fbshipit-source-id: 36b57eac7523d8aa180180c2b61fc693ea4638ae
(cherry picked from commit 05add2ae0f)
This commit is contained in:
parent
10cc66bc78
commit
67a275c293
|
|
@ -2349,6 +2349,42 @@ except RuntimeError as e:
|
|||
# and can cache values safely
|
||||
dataset.start = i
|
||||
|
||||
@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
@unittest.skipIf(IS_WINDOWS, "Needs fork")
def test_early_exit(self):
    """Regression test: the interpreter must exit cleanly while persistent
    workers and the pin_memory thread are still alive.

    The child script breaks out of iteration after the first batch and then
    exits; `subprocess.check_output` raises CalledProcessError if the child
    crashes (and the test hangs if shutdown deadlocks), so merely returning
    is the success criterion.
    """
    import subprocess
    # Fixed in the embedded script: `super(RandomDataset).__init__()` built an
    # unbound super object and never ran the parent initializer; it is now
    # the zero-argument form `super().__init__()`.
    subprocess.check_output([sys.executable, '-c', """\
import torch
from torch.utils.data import DataLoader, IterableDataset

class RandomDataset(IterableDataset):
    def __init__(self, len, size):
        super().__init__()
        self.len = len
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        if self.len <= 0:
            raise StopIteration
        self.len -= 1
        return torch.randn(self.size)

if __name__ == '__main__':
    dl = DataLoader(
        RandomDataset(64, (28, 28)),
        batch_size=16,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
        multiprocessing_context="fork",
    )

    for _ in dl:
        break
"""])
|
||||
|
||||
|
||||
class NamedTupleDataset(Dataset):
|
||||
|
|
|
|||
|
|
@ -946,6 +946,18 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
|||
else:
|
||||
self._data_queue = self._worker_result_queue
|
||||
|
||||
# In some rare cases, persistent workers (daemonic processes)
|
||||
# would be terminated before `__del__` of iterator is invoked
|
||||
# when main process exits
|
||||
# It would cause failure when pin_memory_thread tries to read
|
||||
# corrupted data from worker_result_queue
|
||||
# atexit is used to shutdown thread and child processes in the
|
||||
# right sequence before main process exits
|
||||
if self._persistent_workers and self._pin_memory:
|
||||
import atexit
|
||||
for w in self._workers:
|
||||
atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
|
||||
|
||||
# .pid can be None only before process is spawned (not the case, so ignore)
|
||||
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
|
||||
_utils.signal_handling._set_SIGCHLD_handler()
|
||||
|
|
@ -1333,5 +1345,14 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
|||
# we kill the worker.
|
||||
w.terminate()
|
||||
|
||||
# NOTE: a staticmethod is used here so that the atexit registration does not
# keep a reference to the `_MultiProcessingDataLoaderIter` instance alive.
@staticmethod
def _clean_up_worker(w):
    """Join worker process ``w`` for a bounded interval, then force-terminate
    it if it is still alive.

    Registered with ``atexit`` so that persistent workers are shut down
    before interpreter teardown can corrupt the queue the pin_memory
    thread reads from.
    """
    try:
        # Give the worker a bounded window to exit on its own.
        w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
    finally:
        # Regardless of whether the join raised or timed out, make sure
        # the process is actually gone.
        if w.is_alive():
            w.terminate()
|
||||
|
||||
def __del__(self):
    # Finalizer: ensure worker/thread teardown runs even when the iterator
    # is garbage-collected without being exhausted or explicitly closed.
    self._shutdown_workers()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user