Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary: Fix issue https://github.com/pytorch/pytorch/issues/23141

In the example below, `default_collate` collates each element of the list. Since the second element of `bar` is not present in all samples, it is silently discarded:

```
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

class CustomDataset(Dataset):
    def __len__(self):
        return 2

    def __getitem__(self, idx):
        tmp = {
            "foo": np.array([1, 2, 3]),
            "bar": ["X"] * (idx + 1),
        }
        return tmp

training = CustomDataset()
for batch in DataLoader(training, batch_size=2):
    print(batch)
```

Yields

```
{
    'foo': tensor([[1, 2, 3],
                   [1, 2, 3]]),
    'bar': [('X', 'X')]
}
```

Based on the discussion in the issue, the best course of action is to error out in this case. This is consistent with what is done for tensor elements, as seen in [TensorShape.cpp line 1066](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp#L1060), which runs when `torch.stack` is called. In this PR, I introduce a similar message to error out for lists.

SsnL

Pull Request resolved: https://github.com/pytorch/pytorch/pull/38492

Differential Revision: D21620396

Pulled By: ezyang

fbshipit-source-id: 17f59fbb1ed1f0d9b2185c95b9ebe55ece701b0c
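For reference, a minimal sketch of the behavior after this fix, reusing the `CustomDataset` defined above (the error message comes from the check added in this PR; exact traceback formatting may vary by version):

```
from torch.utils.data import DataLoader

# With the new consistency check, collating 'bar' fails loudly instead of
# silently dropping the extra element:
for batch in DataLoader(CustomDataset(), batch_size=2):
    print(batch)
# RuntimeError: each element in list of batch should be of equal size
```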
87 lines · 3.6 KiB · Python
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
|
|
collate samples fetched from dataset into Tensor(s).
|
|
|
|
These **needs** to be in global scope since Py2 doesn't support serializing
|
|
static methods.
|
|
"""
|
|
|
|
import torch
import re
from torch._six import container_abcs, string_classes, int_classes

np_str_obj_array_pattern = re.compile(r'[SaUO]')

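# Illustrative note: np_str_obj_array_pattern matches the dtype.str of NumPy
# arrays holding bytes ('S'/'a'), unicode ('U'), or Python objects ('O'),
# none of which can be converted to tensors. For example, on a typical
# 64-bit little-endian build:
#     np.array(['a']).dtype.str                 == '<U1'  -> matched
#     np.array([1]).dtype.str                   == '<i8'  -> not matched
#     np.array([None], dtype=object).dtype.str  == '|O'   -> matched
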
def default_convert(data):
    r"""Converts each NumPy array data field into a tensor"""
    elem_type = type(data)
    if isinstance(data, torch.Tensor):
        return data
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        # array of string classes and object
        if elem_type.__name__ == 'ndarray' \
                and np_str_obj_array_pattern.search(data.dtype.str) is not None:
            return data
        return torch.as_tensor(data)
    elif isinstance(data, container_abcs.Mapping):
        return {key: default_convert(data[key]) for key in data}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return elem_type(*(default_convert(d) for d in data))
    elif isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes):
        return [default_convert(d) for d in data]
    else:
        return data

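# Example sketch (doctest-style, assuming numpy is imported as np; output
# abbreviated): default_convert walks containers recursively, converting
# NumPy arrays to tensors and leaving everything else untouched:
#
#     >>> default_convert({"a": np.array([1, 2]), "b": "label"})
#     {'a': tensor([1, 2]), 'b': 'label'}
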
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
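
# Example sketch (doctest-style, output abbreviated) of the recursive
# collation above:
#
#     >>> default_collate([{"x": torch.tensor([1, 2]), "y": 3.0},
#     ...                  {"x": torch.tensor([4, 5]), "y": 6.0}])
#     {'x': tensor([[1, 2], [4, 5]]), 'y': tensor([3., 6.], dtype=torch.float64)}
#
# Lists of unequal length inside a batch now raise the RuntimeError added in
# this change rather than being silently truncated by zip().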