pytorch/torch/utils/data/_utils/collate.py
Donna Choi 8c07a98adc Error out of default_collate for lists of unequal size (#38492)
Summary:
Fix issue https://github.com/pytorch/pytorch/issues/23141#

In the example below, ```default_collate``` collates the lists element by element. Since the second element isn't present in every sample, it is silently discarded:
```
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

class CustomDataset(Dataset):
    def __len__(self):
        return 2

    def __getitem__(self, idx):
        tmp = {
            "foo": np.array([1, 2, 3]),
            "bar": ["X"] * (idx+1),
        }

        return tmp

training = CustomDataset()

for batch in DataLoader(training, batch_size=2):
    print(batch)
```
Yields
```
{
  'foo': tensor(
    [
      [1, 2, 3],
      [1, 2, 3]
    ]
  ),
  'bar': [
      ('X', 'X'),
    ]
}
```

Based on discussion in the issue, it seems the best course of action is to error out in this case. This seems consistent with what is done for tensor elements, as seen in [TensorShape.cpp line 1066](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp#L1060) which is called when ```torch.stack``` is called. In this PR, I introduce a similar message to error out for lists.

SsnL
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38492

Differential Revision: D21620396

Pulled By: ezyang

fbshipit-source-id: 17f59fbb1ed1f0d9b2185c95b9ebe55ece701b0c
2020-05-18 14:53:33 -07:00

87 lines
3.6 KiB
Python

r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).
These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""
import torch
import re
from torch._six import container_abcs, string_classes, int_classes
# Pattern searched against ``ndarray.dtype.str``. The type codes S, a, U and O
# identify NumPy arrays of string classes and objects: default_convert returns
# such arrays unchanged and default_collate rejects them with a TypeError.
np_str_obj_array_pattern = re.compile(r'[SaUO]')
def default_convert(data):
    r"""Converts each NumPy array data field into a tensor.

    Containers are walked recursively so that NumPy values nested inside
    mappings, namedtuples and sequences are converted as well.

    Args:
        data: a single value or a (possibly nested) container of values.

    Returns:
        * ``data`` itself when it is already a ``torch.Tensor``;
        * ``torch.as_tensor(data)`` for NumPy values, except ``ndarray``s of
          string classes / objects, which are returned unchanged;
        * a plain ``dict`` with converted values for mappings;
        * a namedtuple of the same type with converted fields;
        * a ``list`` of converted items for any other non-string sequence;
        * ``data`` unchanged for everything else.
    """
    elem_type = type(data)

    if isinstance(data, torch.Tensor):
        return data

    # NumPy string scalars (``str_`` / ``string_``) are deliberately excluded
    # here so they fall through to the generic handling below.
    is_numpy_value = (
        elem_type.__module__ == 'numpy'
        and elem_type.__name__ not in ('str_', 'string_')
    )
    if is_numpy_value:
        # array of string classes and object
        if elem_type.__name__ == 'ndarray' \
                and np_str_obj_array_pattern.search(data.dtype.str) is not None:
            return data
        return torch.as_tensor(data)

    if isinstance(data, container_abcs.Mapping):
        return {key: default_convert(data[key]) for key in data}

    # The namedtuple check must precede the generic Sequence check, because a
    # namedtuple is also a Sequence but has to be rebuilt with its own type.
    if isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return elem_type(*(default_convert(d) for d in data))

    if isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes):
        return [default_convert(d) for d in data]

    return data
# Template for the TypeError raised by default_collate on unsupported input;
# the ``{}`` placeholder is filled with the offending type or NumPy dtype.
default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}")
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    The handling is chosen from the type of the first sample, ``batch[0]``
    (so ``batch`` must be non-empty and indexable).

    Args:
        batch: a sequence of samples to combine into a single batch.

    Returns:
        * tensors: stacked along a new leading dimension;
        * NumPy ``ndarray``s: converted to tensors, then collated as tensors;
        * NumPy scalars: ``torch.as_tensor(batch)``;
        * ``float``: a ``torch.float64`` tensor;
        * integers: ``torch.tensor(batch)``;
        * strings: ``batch`` returned unchanged;
        * mappings: a ``dict`` whose values are collated per key, using the
          keys of the first sample;
        * namedtuples: a namedtuple of the same type with collated fields;
        * other sequences: a ``list`` with one collated entry per position.

    Raises:
        TypeError: if the samples are of an unsupported type, including
            NumPy arrays of string classes or objects.
        RuntimeError: if the samples are sequences of unequal length.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
        # Any other NumPy value falls through to the TypeError at the end.
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size;
        # zip() below would otherwise silently truncate to the shortest sample
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(sample) == elem_size for sample in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))