import torch
import torch.multiprocessing as multiprocessing
from .sampler import SequentialSampler, RandomSampler
import collections
import math
import sys
import traceback
import threading

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue


class ExceptionWrapper(object):
    "Wraps an exception plus its traceback, to communicate across thread and process boundaries"

    def __init__(self, exc_info):
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))


def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    # workers only index the dataset and collate samples, so a single
    # intra-op thread per worker avoids oversubscribing the CPU
    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            # None is the sentinel telling this worker to shut down
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            # ship the exception (with its traceback) back to the main process
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))


def _pin_memory_loop(in_queue, out_queue, done_event):
    while True:
        try:
            r = in_queue.get()
        except:
            # bare except on purpose: the queue can fail in arbitrary ways
            # during interpreter shutdown, and those errors should only be
            # swallowed if we are already shutting down
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            # pass worker errors straight through to the consumer
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))


def default_collate(batch):
    "Puts each data field into a tensor with outer dimension batch size"
    if torch.is_tensor(batch[0]):
        return torch.stack(batch, 0)
    elif type(batch[0]).__module__ == 'numpy':
        # comparing module names avoids making numpy a hard dependency
        return torch.stack([torch.from_numpy(b) for b in batch], 0)
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], str):
        return batch
    elif isinstance(batch[0], collections.Iterable):
        # if each batch element is not a tensor, then it should be a tuple
        # of tensors; in that case we collate each element in the tuple
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(("batch must contain tensors, numbers, strings or lists; "
                     "found {}".format(type(batch[0]))))


def pin_memory_batch(batch):
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, str):
        # strings cannot be pinned; returning early also avoids infinite
        # recursion, since iterating a one-character string yields itself
        return batch
    elif isinstance(batch, collections.Iterable):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch
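
# A minimal sketch of what `default_collate` produces for tuple-structured
# samples (illustrative only; the sample data below is hypothetical, not part
# of this module):
#
#     samples = [(torch.randn(3), 0), (torch.randn(3), 1)]
#     batch = default_collate(samples)
#     # batch is [FloatTensor of size 2x3, LongTensor of size 2]
#
# Each sample is a (data, label) tuple, so the batch is transposed and each
# field is collated on its own: the tensors are stacked along a new outer
# dimension, and the integer labels become a single LongTensor.
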
class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.batch_size = loader.batch_size
        self.collate_fn = loader.collate_fn
        self.sampler = loader.sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory
        self.done_event = threading.Event()

        self.samples_remaining = len(self.sampler)
        self.sample_iter = iter(self.sampler)

        if self.num_workers > 0:
            self.index_queue = multiprocessing.SimpleQueue()
            self.data_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.data_queue,
                          self.collate_fn))
                for _ in range(self.num_workers)]

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            if self.pin_memory:
                # route worker output through a thread that pins each batch
                in_data = self.data_queue
                self.data_queue = queue.Queue()
                self.pin_thread = threading.Thread(
                    target=_pin_memory_loop,
                    args=(in_data, self.data_queue, self.done_event))
                self.pin_thread.daemon = True
                self.pin_thread.start()

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return int(math.ceil(len(self.sampler) / float(self.batch_size)))

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            if self.samples_remaining == 0:
                raise StopIteration
            indices = self._next_indices()
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next batch has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self.data_queue.get()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order batches until their turn comes
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _next_indices(self):
        batch_size = min(self.samples_remaining, self.batch_size)
        batch = [next(self.sample_iter) for _ in range(batch_size)]
        self.samples_remaining -= len(batch)
        return batch

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        if self.samples_remaining > 0:
            self.index_queue.put((self.send_idx, self._next_indices()))
            self.batches_outstanding += 1
            self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            # re-raise the worker's exception in the main process
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("DataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
        if not self.shutdown:
            self.shutdown = True
            self.done_event.set()
            # one None sentinel per worker ends each worker loop
            for _ in self.workers:
                self.index_queue.put(None)

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()


class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, the ``shuffle`` argument is ignored.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process
            (default: 0).
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch (default: ``default_collate``).
        pin_memory (bool, optional): if ``True``, batches are copied into
            pinned (page-locked) memory before being returned (default: False).
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory

        if sampler is not None:
            self.sampler = sampler
        elif shuffle:
            self.sampler = RandomSampler(dataset)
        else:
            self.sampler = SequentialSampler(dataset)

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return int(math.ceil(len(self.sampler) / float(self.batch_size)))
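
# A minimal usage sketch (not part of the original module; `my_dataset` and
# `train_step` below are hypothetical stand-ins for a real Dataset and
# training function):
#
#     loader = DataLoader(my_dataset, batch_size=32, shuffle=True,
#                         num_workers=4, pin_memory=True)
#     for inputs, targets in loader:
#         train_step(inputs, targets)
#
# With num_workers=4, batches are collated in four subprocesses and the
# iterator keeps up to 2 * num_workers batches in flight; with
# pin_memory=True, a background thread copies each batch into page-locked
# memory so subsequent host-to-GPU transfers can be asynchronous.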