import torch
import torch.multiprocessing as multiprocessing
from .sampler import SequentialSampler, RandomSampler
import collections
import math
import sys
import traceback
import threading

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue


class ExceptionWrapper(object):
    "Wraps an exception plus traceback to communicate across threads"

    def __init__(self, exc_info):
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))


def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    # Limit each worker process to a single thread for CPU ops to avoid
    # oversubscribing the machine when several workers run at once.
    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            # ``None`` is the shutdown signal from the main process.
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            # Forward the exception to the main process instead of dying silently.
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))


def _pin_memory_loop(in_queue, out_queue, done_event):
    # Runs in a background thread of the main process: copies each batch into
    # pinned (page-locked) memory so host-to-GPU transfers can be faster.
    while True:
        try:
            r = in_queue.get()
        except:
            # The queue may raise while the interpreter is shutting down;
            # exit quietly if shutdown has already been requested.
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            # Pass worker errors straight through.
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))


def default_collate(batch):
    "Puts each data field into a tensor with outer dimension batch size"
    if sys.version_info[0] < 3:
        # handle the case of unicode strings in Python 2.7
        if isinstance(batch[0], unicode):
            return batch
    if torch.is_tensor(batch[0]):
        return torch.stack(batch, 0)
    elif type(batch[0]).__module__ == 'numpy':  # check the module name so numpy is never imported here
        return torch.stack([torch.from_numpy(b) for b in batch], 0)
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], str):
        return batch
    elif isinstance(batch[0], collections.Iterable):
        # if each batch element is not a tensor, then it should be a tuple
        # of tensors; in that case we collate each element in the tuple
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(("batch must contain tensors, numbers, or lists; found {}"
                     .format(type(batch[0]))))
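
# Illustrative only (not executed): a rough sketch of how default_collate maps a
# list of samples onto a batched result. The sample values below are made up.
#
#   >>> default_collate([1, 2, 3]).size()                     # numbers -> LongTensor
#   torch.Size([3])
#   >>> default_collate([torch.ones(2), torch.ones(2)]).size()
#   torch.Size([2, 2])
#   >>> xs = [(torch.ones(2), 0), (torch.zeros(2), 1)]        # (input, target) tuples
#   >>> inputs, targets = default_collate(xs)                  # collated per field
#   >>> inputs.size(), targets.size()
#   (torch.Size([2, 2]), torch.Size([2]))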


def pin_memory_batch(batch):
    # Recursively pin every tensor found in a (possibly nested) batch structure.
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, collections.Iterable):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch
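
# Illustrative only (not executed): pin_memory_batch preserves the structure
# produced by default_collate, pinning each tensor it finds. Pinning tensors
# is generally only useful on a machine with a working CUDA runtime.
#
#   >>> batch = [torch.ones(2, 2), torch.LongTensor([0, 1])]
#   >>> pinned = pin_memory_batch(batch)    # a list of two pinned tensors
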
class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.batch_size = loader.batch_size
        self.collate_fn = loader.collate_fn
        self.sampler = loader.sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory
        self.done_event = threading.Event()

        self.samples_remaining = len(self.sampler)
        self.sample_iter = iter(self.sampler)

        if self.num_workers > 0:
            self.index_queue = multiprocessing.SimpleQueue()
            self.data_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            # Workers may return batches out of order; completed batches are
            # parked here until their index comes up.
            self.reorder_dict = {}

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
                for _ in range(self.num_workers)]

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            if self.pin_memory:
                # Route worker output through a pinning thread; consumers keep
                # reading from self.data_queue as before.
                in_data = self.data_queue
                self.data_queue = queue.Queue()
                self.pin_thread = threading.Thread(
                    target=_pin_memory_loop,
                    args=(in_data, self.data_queue, self.done_event))
                self.pin_thread.daemon = True
                self.pin_thread.start()

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return int(math.ceil(len(self.sampler) / float(self.batch_size)))

    def __next__(self):
        if self.num_workers == 0:
            # same-process loading
            if self.samples_remaining == 0:
                raise StopIteration
            indices = self._next_indices()
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next batch has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self.data_queue.get()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order batches until their turn comes
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _next_indices(self):
        batch_size = min(self.samples_remaining, self.batch_size)
        batch = [next(self.sample_iter) for _ in range(batch_size)]
        self.samples_remaining -= len(batch)
        return batch

    def _put_indices(self):
        # keep at most 2 * num_workers batches in flight at any time
        assert self.batches_outstanding < 2 * self.num_workers
        if self.samples_remaining > 0:
            self.index_queue.put((self.send_idx, self._next_indices()))
            self.batches_outstanding += 1
            self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            # re-raise the worker's exception in the main process
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("DataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
        if not self.shutdown:
            self.shutdown = True
            self.done_event.set()
            # one sentinel per worker so every worker loop exits
            for _ in self.workers:
                self.index_queue.put(None)

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()


class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, the ``shuffle`` argument is ignored.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process
            (default: 0).
        collate_fn (callable, optional): merges a list of samples into a
            mini-batch (default: default_collate).
        pin_memory (bool, optional): if ``True``, batches are copied into
            pinned (page-locked) memory before being returned (default: False).

    An illustrative usage sketch appears at the end of this module.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory

        if sampler is not None:
            self.sampler = sampler
        elif shuffle:
            self.sampler = RandomSampler(dataset)
        elif not shuffle:
            self.sampler = SequentialSampler(dataset)

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return int(math.ceil(len(self.sampler) / float(self.batch_size)))
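
# Illustrative only (not executed): a minimal sketch of how this module is meant
# to be used. ``MyDataset`` is a made-up dataset; any object with __getitem__ and
# __len__ that returns (tensor, label) pairs would work the same way.
#
#   >>> class MyDataset(object):
#   ...     def __getitem__(self, i):
#   ...         return torch.ones(3), i % 2
#   ...     def __len__(self):
#   ...         return 10
#   >>> loader = DataLoader(MyDataset(), batch_size=4, shuffle=True, num_workers=2)
#   >>> for inputs, targets in loader:      # inputs: 4x3 FloatTensor (last batch 2x3)
#   ...     pass                            # targets: LongTensor of labels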