mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Back in April, malmaud added type annotations for `dataloader.py`. However, at about the same time, SsnL in https://github.com/pytorch/pytorch/issues/19228 replaced `_DataLoaderIter` with `_BaseDataLoaderIter` and two subclasses, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`. Probably because these changes happened in parallel at roughly the same time, the type stubs and several other references in the codebase were never updated to match this refactoring. I've gone ahead and done the updates to reflect the refactoring in https://github.com/pytorch/pytorch/issues/19228, which fixes the specific type stub/implementation mismatch pointed out in https://github.com/pytorch/pytorch/issues/26673, although not the broader problem that pytorch doesn't have a test to make sure that the `.pyi` type stub files match the real API defined in `.py` files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/27105 Differential Revision: D17813641 Pulled By: ezyang fbshipit-source-id: ed7ac025c8d6ad3f298dd073347ec83bb4b6600c
82 lines
3.3 KiB
Python
82 lines
3.3 KiB
Python
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
|
|
collate samples fetched from dataset into Tensor(s).
|
|
|
|
These **need** to be in global scope since Py2 doesn't support serializing
|
|
static methods.
|
|
"""
|
|
|
|
import torch
|
|
import re
|
|
from torch._six import container_abcs, string_classes, int_classes
|
|
|
|
# Matches any of the letters S, a, U or O. The collate helpers in this module
# search a NumPy array's ``dtype.str`` with it to spot arrays of strings or
# objects, which they do not turn into tensors.
np_str_obj_array_pattern = re.compile(r'[SaUO]')
|
|
|
|
|
|
def default_convert(data):
    r"""Recursively turn NumPy data inside ``data`` into tensors.

    Tensors are returned untouched. NumPy values are passed through
    ``torch.as_tensor``, except NumPy string scalars and arrays whose dtype
    string matches ``np_str_obj_array_pattern``, which are returned as they
    are. Mappings are rebuilt as a ``dict``, namedtuples as the same
    namedtuple type, and any other non-string sequence as a ``list``, with
    every contained value converted the same way. Anything else is returned
    unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data

    kind = type(data)
    if kind.__module__ == 'numpy' and kind.__name__ not in ('str_', 'string_'):
        # array of string classes and object: leave those untouched
        is_array = kind.__name__ == 'ndarray'
        if is_array and np_str_obj_array_pattern.search(data.dtype.str) is not None:
            return data
        return torch.as_tensor(data)

    if isinstance(data, container_abcs.Mapping):
        return dict((key, default_convert(data[key])) for key in data)

    if isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return kind(*map(default_convert, data))

    if isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes):
        return list(map(default_convert, data))

    return data
|
|
|
|
|
|
# Template for the TypeError raised by default_collate; the single ``{}``
# placeholder is filled with the offending type or dtype.
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}"
)
|
|
|
|
|
|
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    The first element of ``batch`` decides how the whole batch is merged:

    * tensors are stacked along a new leading dimension;
    * NumPy arrays are converted to tensors and collated recursively, and
      zero-dimensional NumPy values are turned into a single tensor;
    * Python floats become a ``torch.float64`` tensor, integers go through
      ``torch.tensor`` with its inferred dtype;
    * strings are not collated, ``batch`` itself is returned;
    * mappings are collated key by key into a ``dict``;
    * namedtuples are collated field by field into the same namedtuple type;
    * any other sequence is collated position by position into a ``list``.

    Args:
        batch: non-empty sequence of samples structured like the first one.

    Returns:
        The collated batch, structured as described above.

    Raises:
        TypeError: if the first sample has an unsupported type, including
            NumPy arrays of strings or objects.
        RuntimeError: if the samples are plain sequences of differing
            lengths.
    """

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
        # Any other NumPy value falls through to the TypeError at the bottom.
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # zip() stops at the shortest sample, so a ragged batch would lose
        # trailing fields without any warning; reject it explicitly instead.
        elem_size = len(elem)
        if any(len(sample) != elem_size for sample in batch):
            raise RuntimeError('each element in list of batch should be of equal size')
        return [default_collate(samples) for samples in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
|