Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary: Same as #14668, and was approved there. ailzhang, please apply this patch to Horizon's `data_streamer.py`: https://gist.github.com/SsnL/020fdb3d6b7016d81b6ba1d04cc41459 Thank you! Below is the original description from #14668: While working on the tasks in https://github.com/pytorch/pytorch/issues/13023, I realized how unreadable the code has become, because every function run in multiprocessing must live at top-level (global) scope. Adding more functionality to `dataloader.py` will only make things worse. So in this PR, I refactor `dataloader.py` and move much of it into `data._utils`. E.g., `_worker_loop` and related methods now live in `data._utils.worker`, signal handling code in `data._utils.signal_handling`, collating code in `data._utils.collate`, etc. This split, IMHO, makes the code much clearer. I will base my future changes to DataLoader on top of this. No functionality is changed, except that I added `torch._six.queue`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15331 Reviewed By: yf225 Differential Revision: D13503120 Pulled By: ailzhang fbshipit-source-id: 94df16b4d80ad1102c437cde0d5a2e62cffe1f8e
69 lines · 2.6 KiB · Python
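For orientation, here is a minimal usage sketch of the split described in the summary. It assumes the collating code lands at `torch.utils.data._utils.collate`, as the summary states; the `TensorDataset` data below is made up purely for illustration.

import torch
from torch.utils.data import DataLoader, TensorDataset
# Assumption: after this refactor the collate helpers live under
# torch.utils.data._utils.collate (per the summary above).
from torch.utils.data._utils.collate import default_collate

dataset = TensorDataset(torch.randn(8, 3), torch.arange(8))
# DataLoader behavior is unchanged; default_collate can still be passed
# explicitly as collate_fn, now imported from its new location.
loader = DataLoader(dataset, batch_size=4, collate_fn=default_collate)
for features, labels in loader:
    print(features.shape, labels.shape)  # torch.Size([4, 3]) torch.Size([4])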
r""""Contains definitions of the methods used by the _DataLoaderIter workers to
|
|
collate samples fetched from dataset into Tensor(s).
|
|
|
|
These **needs** to be in global scope since Py2 doesn't support serializing
|
|
static methods.
|
|
"""
|
|
|
|
import torch
|
|
import re
|
|
from torch._six import container_abcs, string_classes, int_classes
|
|
|
|
_use_shared_memory = False
|
|
r"""Whether to use shared memory in default_collate"""
|
|
|
|
np_str_obj_array_pattern = re.compile(r'[SaUO]')
|
|
|
|
error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"
|
|
|
|
numpy_type_map = {
|
|
'float64': torch.DoubleTensor,
|
|
'float32': torch.FloatTensor,
|
|
'float16': torch.HalfTensor,
|
|
'int64': torch.LongTensor,
|
|
'int32': torch.IntTensor,
|
|
'int16': torch.ShortTensor,
|
|
'int8': torch.CharTensor,
|
|
'uint8': torch.ByteTensor,
|
|
}
|
|
|
|
|
|
def default_collate(batch):
|
|
r"""Puts each data field into a tensor with outer dimension batch size"""
|
|
|
|
elem_type = type(batch[0])
|
|
if isinstance(batch[0], torch.Tensor):
|
|
out = None
|
|
if _use_shared_memory:
|
|
# If we're in a background process, concatenate directly into a
|
|
# shared memory tensor to avoid an extra copy
|
|
numel = sum([x.numel() for x in batch])
|
|
storage = batch[0].storage()._new_shared(numel)
|
|
out = batch[0].new(storage)
|
|
return torch.stack(batch, 0, out=out)
|
|
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
|
and elem_type.__name__ != 'string_':
|
|
elem = batch[0]
|
|
if elem_type.__name__ == 'ndarray':
|
|
# array of string classes and object
|
|
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
|
|
raise TypeError(error_msg_fmt.format(elem.dtype))
|
|
|
|
return torch.stack([torch.from_numpy(b) for b in batch], 0)
|
|
if elem.shape == (): # scalars
|
|
py_type = float if elem.dtype.name.startswith('float') else int
|
|
return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
|
|
elif isinstance(batch[0], int_classes):
|
|
return torch.LongTensor(batch)
|
|
elif isinstance(batch[0], float):
|
|
return torch.DoubleTensor(batch)
|
|
elif isinstance(batch[0], string_classes):
|
|
return batch
|
|
elif isinstance(batch[0], container_abcs.Mapping):
|
|
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
|
|
elif isinstance(batch[0], container_abcs.Sequence):
|
|
transposed = zip(*batch)
|
|
return [default_collate(samples) for samples in transposed]
|
|
|
|
raise TypeError((error_msg_fmt.format(type(batch[0]))))
|
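As a quick illustration of the recursive behavior above, here is a hedged usage sketch of `default_collate` on a batch of dict-shaped samples; the sample structure and key names are made up for this example, and the import path follows the summary above.

import torch
from torch.utils.data._utils.collate import default_collate  # path per the summary above

# Each sample is a dict; default_collate recurses into the mapping and
# collates each field independently.
batch = [
    {'x': torch.randn(3), 'y': 0},
    {'x': torch.randn(3), 'y': 1},
]
collated = default_collate(batch)
print(collated['x'].shape)  # torch.Size([2, 3]) -- tensors stacked along a new batch dim
print(collated['y'])        # tensor([0, 1])     -- Python ints collected into a LongTensor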