mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: `default_collate`, `default_convert`, and `pin_memory` convert sequences into lists. I believe they should keep the original type when possible (e.g., I have a class that inherits from `list`, which comes from a 3rd party library that I can't change, and provides extra functionality). Note it's easy to do when the type supports an iterable in its creation but it's not always the case (e.g., `range`). Even though this can be accomplished if using a custom `default_collate`/`default_convert`, 1) this is behavior they should support out-of-the-box IMHO, and 2) `pin_memory` still does it. cc VitalyFedyunin ejguan NivekT Pull Request resolved: https://github.com/pytorch/pytorch/pull/68779 Reviewed By: wenleix Differential Revision: D32651129 Pulled By: ejguan fbshipit-source-id: 17c390934bacc0e4ead060469cf15dde815550b4
109 lines
4.7 KiB
Python
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).

These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""
|
|
|
|
import torch
|
|
import re
|
|
import collections
|
|
from torch._six import string_classes
|
|
|
|
# Matches NumPy dtype `str` codes for string ('S', 'a', 'U') and object ('O')
# arrays; such arrays cannot be converted to tensors, so `default_convert`
# passes them through and `default_collate` rejects them.
np_str_obj_array_pattern = re.compile(r'[SaUO]')
|
|
|
|
|
|
def default_convert(data):
    r"""Recursively converts each NumPy array data field into a tensor.

    Tensors are returned unchanged. Mappings, namedtuples, and sequences are
    rebuilt with converted elements, keeping the original container type when
    its constructor accepts an iterable (falling back to a plain dict/list
    otherwise). Any other value passes through untouched.
    """
    data_type = type(data)
    if isinstance(data, torch.Tensor):
        # Already a tensor; nothing to convert.
        return data
    elif data_type.__module__ == 'numpy' and data_type.__name__ != 'str_' \
            and data_type.__name__ != 'string_':
        # Arrays of string classes and object dtype cannot become tensors.
        if data_type.__name__ == 'ndarray' \
                and np_str_obj_array_pattern.search(data.dtype.str) is not None:
            return data
        return torch.as_tensor(data)
    elif isinstance(data, collections.abc.Mapping):
        try:
            # Keep the caller's mapping type when it can be built from a dict.
            return data_type({key: default_convert(data[key]) for key in data})
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return {key: default_convert(data[key]) for key in data}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return data_type(*(default_convert(item) for item in data))
    elif isinstance(data, tuple):
        return [default_convert(item) for item in data]  # Backwards compatibility.
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, string_classes):
        try:
            # Keep the caller's sequence type when possible.
            return data_type([default_convert(item) for item in data])
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return [default_convert(item) for item in data]
    else:
        return data
|
|
|
|
|
|
# Message raised by `default_collate` when a batch element has a type it
# cannot collate; `{}` is filled with the offending type (or numpy dtype).
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")
|
|
|
|
|
|
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    Dispatches on the type of the first batch element: tensors are stacked,
    numpy arrays/scalars and Python numbers are converted then stacked,
    strings pass through, and mappings/namedtuples/sequences are collated
    recursively per field (keeping the container type when its constructor
    accepts an iterable). Unsupported types raise ``TypeError``.
    """
    first = batch[0]
    first_type = type(first)
    if isinstance(first, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # In a background worker: stack directly into a shared-memory
            # tensor so the main process doesn't pay for an extra copy.
            total_numel = sum(sample.numel() for sample in batch)
            shared_storage = first.storage()._new_shared(total_numel)
            out = first.new(shared_storage)
        return torch.stack(batch, 0, out=out)
    elif first_type.__module__ == 'numpy' and first_type.__name__ != 'str_' \
            and first_type.__name__ != 'string_':
        if first_type.__name__ == 'ndarray' or first_type.__name__ == 'memmap':
            # Arrays of string classes and object dtype cannot be collated.
            if np_str_obj_array_pattern.search(first.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(first.dtype))

            return default_collate([torch.as_tensor(sample) for sample in batch])
        elif first.shape == ():  # numpy scalars
            return torch.as_tensor(batch)
    elif isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(first, int):
        return torch.tensor(batch)
    elif isinstance(first, string_classes):
        return batch
    elif isinstance(first, collections.abc.Mapping):
        try:
            # Keep the caller's mapping type when it can be built from a dict.
            return first_type({key: default_collate([sample[key] for sample in batch]) for key in first})
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return {key: default_collate([sample[key] for sample in batch]) for key in first}
    elif isinstance(first, tuple) and hasattr(first, '_fields'):  # namedtuple
        return first_type(*(default_collate(group) for group in zip(*batch)))
    elif isinstance(first, collections.abc.Sequence):
        # Check that all elements in the batch have a consistent size.
        rest = iter(batch)
        expected_len = len(next(rest))
        if any(len(sample) != expected_len for sample in rest):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

        if isinstance(first, tuple):
            return [default_collate(group) for group in transposed]  # Backwards compatibility.
        else:
            try:
                # Keep the caller's sequence type when possible.
                return first_type([default_collate(group) for group in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)` (e.g., `range`).
                return [default_collate(group) for group in transposed]

    raise TypeError(default_collate_err_msg_format.format(first_type))
|