[DataLoader] more clearly expose 'default_collate' and 'default_convert' to users (#69862)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69862

Fixes #69445

cc SsnL VitalyFedyunin ejguan NivekT

Test Plan: Imported from OSS

Reviewed By: ejguan, ngimel

Differential Revision: D33068792

Pulled By: NivekT

fbshipit-source-id: ef9791acdc23d014b8761fa7420062d454ce8969
Kevin Tse, 2021-12-14 11:16:53 -08:00, committed by Facebook GitHub Bot
parent 1188d89a1d
commit b67eaec853
4 changed files with 83 additions and 3 deletions


@@ -421,6 +421,8 @@ Example::
.. autoclass:: ConcatDataset
.. autoclass:: ChainDataset
.. autoclass:: Subset
.. autofunction:: torch.utils.data.default_collate
.. autofunction:: torch.utils.data.default_convert
.. autofunction:: torch.utils.data.get_worker_info
.. autofunction:: torch.utils.data.random_split
.. autoclass:: torch.utils.data.Sampler


@@ -25,6 +25,8 @@ from torch.utils.data.dataloader import (
DataLoader,
_DatasetKind,
get_worker_info,
default_collate,
default_convert,
)
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data._decorator import (
@@ -58,6 +60,8 @@ __all__ = ['BatchSampler',
'_DatasetKind',
'argument_validation',
'communication',
'default_collate',
'default_convert',
'functional_datapipe',
'get_worker_info',
'guaranteed_datapipes_determinism',
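
With these re-exports in place, both helpers are importable straight from `torch.utils.data`. A minimal usage sketch under that assumption (the `passthrough_collate` wrapper below is purely illustrative, not part of this change):

import torch
from torch.utils.data import DataLoader, TensorDataset, default_collate

# A toy dataset of (feature, label) pairs.
dataset = TensorDataset(torch.arange(8).float().unsqueeze(1), torch.arange(8))

def passthrough_collate(batch):
    # `batch` is a list of (feature, label) tuples; delegate the actual
    # stacking to the now-public default_collate.
    return default_collate(batch)

loader = DataLoader(dataset, batch_size=4, collate_fn=passthrough_collate)
features, labels = next(iter(loader))
print(features.shape)  # torch.Size([4, 1])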


@@ -3,6 +3,8 @@ collate samples fetched from dataset into Tensor(s).
These **need** to be in global scope since Py2 doesn't support serializing
static methods.
`default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
"""
import torch
@@ -14,7 +16,36 @@ np_str_obj_array_pattern = re.compile(r'[SaUO]')
def default_convert(data):
r"""Converts each NumPy array data field into a tensor"""
r"""
Function that converts each NumPy array element into a :class:`torch.Tensor`. If the input is a `Sequence`,
`Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
If the input is not a NumPy array, it is left unchanged.
This is used as the default function for collation when both `batch_sampler` and
`batch_size` are NOT defined in :class:`~torch.utils.data.DataLoader`.
The general input type to output type mapping is similar to that
of :func:`~torch.utils.data.default_collate`. See the description there for more details.
Args:
data: a single data point to be converted
Examples:
>>> # Example with `int`
>>> default_convert(0)
0
>>> # Example with NumPy array
>>> default_convert(np.array([0, 1]))
tensor([0, 1])
>>> # Example with NamedTuple
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_convert(Point(0, 0))
Point(x=0, y=0)
>>> default_convert(Point(np.array(0), np.array(0)))
Point(x=tensor(0), y=tensor(0))
>>> # Example with List
>>> default_convert([np.array([0, 1]), np.array([2, 3])])
[tensor([0, 1]), tensor([2, 3])]
"""
elem_type = type(data)
if isinstance(data, torch.Tensor):
return data
@@ -51,8 +82,49 @@ default_collate_err_msg_format = (
def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
r"""
Function that takes in a batch of data and puts the elements within the batch
into a tensor with an additional outer dimension - batch size. The exact output type can be
a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
`Collection` of :class:`torch.Tensor`, or left unchanged, depending on the input type.
This is used as the default function for collation when
`batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.
Here is the general input type (based on the type of the element within the batch) to output type mapping:
* :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
* NumPy Arrays -> :class:`torch.Tensor`
* `float` -> :class:`torch.Tensor`
* `int` -> :class:`torch.Tensor`
* `str` -> `str` (unchanged)
* `bytes` -> `bytes` (unchanged)
* `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
* `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]`
* `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]`
Args:
batch: a single batch to be collated
Examples:
>>> # Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])
>>> # Example with a batch of `str`s:
>>> default_collate(['a', 'b', 'c'])
['a', 'b', 'c']
>>> # Example with `Mapping` inside the batch:
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
{'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
>>> # Example with `NamedTuple` inside the batch:
>>> Point = namedtuple('Point', ['x', 'y'])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # Example with `Tuple` inside the batch:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]
>>> # Example with `List` inside the batch:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]
"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
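
Taken together, the two docstrings above boil down to one behavioural difference: `default_convert` only converts types and preserves shape, while `default_collate` stacks a batch along a new outer dimension. A small sketch illustrating both, assuming NumPy is available:

import numpy as np
from torch.utils.data import default_collate, default_convert

sample = np.array([0, 1, 2])

# default_convert: a single NumPy array becomes a tensor of the same shape;
# no batch dimension is added.
print(default_convert(sample).shape)                     # torch.Size([3])

# default_collate: a batch of three samples gains an outer batch dimension.
print(default_collate([sample, sample, sample]).shape)   # torch.Size([3, 3])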


@@ -33,15 +33,17 @@ _worker_init_fn_t = Callable[[int], None]
_collate_fn_t = Callable[[List[T]], Any]
# This function used to be defined in this file. However, it was moved to
# These functions used to be defined in this file. However, they were moved to
# _utils/collate.py. Although it is rather hard to access them from user land
# (one has to explicitly and directly `import torch.utils.data.dataloader`), there
# probably is user code out there using them. This aliasing maintains BC in this
# aspect.
default_collate: _collate_fn_t = _utils.collate.default_collate
default_convert = _utils.collate.default_convert
get_worker_info = _utils.worker.get_worker_info
class _DatasetKind(object):
Map = 0
Iterable = 1
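
Because the `__init__.py` hunk above re-exports these very aliases, the old semi-private access path and the new public one should resolve to the same function objects. A quick sanity check, assuming that wiring:

import torch.utils.data
import torch.utils.data.dataloader  # the old, roundabout access path kept for BC

# Both names should point at the same underlying functions.
assert torch.utils.data.dataloader.default_collate is torch.utils.data.default_collate
assert torch.utils.data.dataloader.default_convert is torch.utils.data.default_convert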