mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[DateLoader] more clearly expose 'default_collate' and 'default_convert' to users (#69862)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69862 Fixes #69445 cc SsnL VitalyFedyunin ejguan NivekT Test Plan: Imported from OSS Reviewed By: ejguan, ngimel Differential Revision: D33068792 Pulled By: NivekT fbshipit-source-id: ef9791acdc23d014b8761fa7420062d454ce8969
This commit is contained in:
parent
1188d89a1d
commit
b67eaec853
|
|
@ -421,6 +421,8 @@ Example::
|
|||
.. autoclass:: ConcatDataset
|
||||
.. autoclass:: ChainDataset
|
||||
.. autoclass:: Subset
|
||||
.. autofunction:: torch.utils.data.default_collate
|
||||
.. autofunction:: torch.utils.data.default_convert
|
||||
.. autofunction:: torch.utils.data.get_worker_info
|
||||
.. autofunction:: torch.utils.data.random_split
|
||||
.. autoclass:: torch.utils.data.Sampler
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ from torch.utils.data.dataloader import (
|
|||
DataLoader,
|
||||
_DatasetKind,
|
||||
get_worker_info,
|
||||
default_collate,
|
||||
default_convert,
|
||||
)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data._decorator import (
|
||||
|
|
@ -58,6 +60,8 @@ __all__ = ['BatchSampler',
|
|||
'_DatasetKind',
|
||||
'argument_validation',
|
||||
'communication',
|
||||
'default_collate',
|
||||
'default_convert',
|
||||
'functional_datapipe',
|
||||
'get_worker_info',
|
||||
'guaranteed_datapipes_determinism',
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ collate samples fetched from dataset into Tensor(s).
|
|||
|
||||
These **needs** to be in global scope since Py2 doesn't support serializing
|
||||
static methods.
|
||||
|
||||
`default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
|
@ -14,7 +16,36 @@ np_str_obj_array_pattern = re.compile(r'[SaUO]')
|
|||
|
||||
|
||||
def default_convert(data):
|
||||
r"""Converts each NumPy array data field into a tensor"""
|
||||
r"""
|
||||
Function that converts each NumPy array element into a :class:`torch.Tensor`. If the input is a `Sequence`,
|
||||
`Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
|
||||
If the input is not an NumPy array, it is left unchanged.
|
||||
This is used as the default function for collation when both `batch_sampler` and
|
||||
`batch_size` are NOT defined in :class:`~torch.utils.data.DataLoader`.
|
||||
|
||||
The general input type to output type mapping is similar to that
|
||||
of :func:`~torch.utils.data.default_collate`. See the description there for more details.
|
||||
|
||||
Args:
|
||||
data: a single data point to be converted
|
||||
|
||||
Examples:
|
||||
>>> # Example with `int`
|
||||
>>> default_convert(0)
|
||||
0
|
||||
>>> # Example with NumPy array
|
||||
>>> default_convert(np.array([0, 1]))
|
||||
tensor([0, 1])
|
||||
>>> # Example with NamedTuple
|
||||
>>> Point = namedtuple('Point', ['x', 'y'])
|
||||
>>> default_convert(Point(0, 0))
|
||||
Point(x=0, y=0)
|
||||
>>> default_convert(Point(np.array(0), np.array(0)))
|
||||
Point(x=tensor(0), y=tensor(0))
|
||||
>>> # Example with List
|
||||
>>> default_convert([np.array([0, 1]), np.array([2, 3])])
|
||||
[tensor([0, 1]), tensor([2, 3])]
|
||||
"""
|
||||
elem_type = type(data)
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data
|
||||
|
|
@ -51,8 +82,49 @@ default_collate_err_msg_format = (
|
|||
|
||||
|
||||
def default_collate(batch):
|
||||
r"""Puts each data field into a tensor with outer dimension batch size"""
|
||||
r"""
|
||||
Function that takes in a batch of data and puts the elements within the batch
|
||||
into a tensor with an additional outer dimension - batch size. The exact output type can be
|
||||
a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
|
||||
Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
|
||||
This is used as the default function for collation when
|
||||
`batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.
|
||||
|
||||
Here is the general input type (based on the type of the element within the batch) to output type mapping:
|
||||
* :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
|
||||
* NumPy Arrays -> :class:`torch.Tensor`
|
||||
* `float` -> :class:`torch.Tensor`
|
||||
* `int` -> :class:`torch.Tensor`
|
||||
* `str` -> `str` (unchanged)
|
||||
* `bytes` -> `bytes` (unchanged)
|
||||
* `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
|
||||
* `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]`
|
||||
* `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]), default_collate([V2_1, V2_2, ...]), ...]`
|
||||
|
||||
Args:
|
||||
batch: a single batch to be collated
|
||||
|
||||
Examples:
|
||||
>>> # Example with a batch of `int`s:
|
||||
>>> default_collate([0, 1, 2, 3])
|
||||
tensor([0, 1, 2, 3])
|
||||
>>> # Example with a batch of `str`s:
|
||||
>>> default_collate(['a', 'b', 'c'])
|
||||
['a', 'b', 'c']
|
||||
>>> # Example with `Map` inside the batch:
|
||||
>>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]
|
||||
{'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
|
||||
>>> # Example with `NamedTuple` inside the batch:
|
||||
>>> Point = namedtuple('Point', ['x', 'y'])
|
||||
>>> default_collate([Point(0, 0), Point(1, 1)])
|
||||
Point(x=tensor([0, 1]), y=tensor([0, 1]))
|
||||
>>> # Example with `Tuple` inside the batch:
|
||||
>>> default_collate([(0, 1), (2, 3)])
|
||||
[tensor([0, 2]), tensor([1, 3])]
|
||||
>>> # Example with `List` inside the batch:
|
||||
>>> default_collate([[0, 1], [2, 3]])
|
||||
[tensor([0, 2]), tensor([1, 3])]
|
||||
"""
|
||||
elem = batch[0]
|
||||
elem_type = type(elem)
|
||||
if isinstance(elem, torch.Tensor):
|
||||
|
|
|
|||
|
|
@ -33,15 +33,17 @@ _worker_init_fn_t = Callable[[int], None]
|
|||
_collate_fn_t = Callable[[List[T]], Any]
|
||||
|
||||
|
||||
# This function used to be defined in this file. However, it was moved to
|
||||
# These functions used to be defined in this file. However, it was moved to
|
||||
# _utils/collate.py. Although it is rather hard to access this from user land
|
||||
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
|
||||
# probably is user code out there using it. This aliasing maintains BC in this
|
||||
# aspect.
|
||||
default_collate: _collate_fn_t = _utils.collate.default_collate
|
||||
default_convert = _utils.collate.default_convert
|
||||
|
||||
get_worker_info = _utils.worker.get_worker_info
|
||||
|
||||
|
||||
class _DatasetKind(object):
|
||||
Map = 0
|
||||
Iterable = 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user