Commit Graph

10 Commits

Author SHA1 Message Date
Donna Choi
8c07a98adc Error out of default_collate for lists of unequal size (#38492)
Summary:
Fix issue https://github.com/pytorch/pytorch/issues/23141#

In the below example ```default_collate``` collates each element of the list. Since the second element isn't present in all samples, it is discarded:
```
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

class CustomDataset(Dataset):
    def __len__(self):
        return 2

    def __getitem__(self, idx):
        tmp = {
            "foo": np.array([1, 2, 3]),
            "bar": ["X"] * (idx+1),
        }

        return tmp

training = CustomDataset()

for batch in DataLoader(training, batch_size=2):
    print(batch)
```
Yields
```
{
  'foo': tensor(
    [
      [1, 2, 3],
      [1, 2, 3]
    ]
  ),
  'bar': [
      ('X', 'X'),
    ]
}
```

Based on discussion in the issue, it seems the best course of action is to error out in this case. This seems consistent with what is done for tensor elements, as seen in [TensorShape.cpp line 1066](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp#L1060) which is called when ```torch.stack``` is called. In this PR, I introduce a similar message to error out for lists.

SsnL
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38492

Differential Revision: D21620396

Pulled By: ezyang

fbshipit-source-id: 17f59fbb1ed1f0d9b2185c95b9ebe55ece701b0c
2020-05-18 14:53:33 -07:00
Nathan Goldbaum
f522bde121 Replace references to _DataLoaderIter with _BaseDataLoaderIter (#27105)
Summary:
Back in April, malmaud added type annotations for `dataloader.py`. However, at about the same time, SsnL in https://github.com/pytorch/pytorch/issues/19228 replaced `_DataLoaderIter` with `_BaseDataLoaderIter` and two subclasses, `_SingleProcessDataLoaderIter`, and `_MultiProcessingDataLoaderIter`. However - probably because these changes happened in parallel at roughly the same time, the type stubs and several other references in the codebase were never updated to match this refactoring.

I've gone ahead and done the updates to reflect the refactoring in https://github.com/pytorch/pytorch/issues/19228, which fixes the specific type stub/impelementation mismatch pointed out in https://github.com/pytorch/pytorch/issues/26673, although not the broader problem that pytorch doesn't have a test to make sure that the `.pyi` type stub files match the real API defined in `.py` files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27105

Differential Revision: D17813641

Pulled By: ezyang

fbshipit-source-id: ed7ac025c8d6ad3f298dd073347ec83bb4b6600c
2019-10-08 12:09:02 -07:00
SsnL
df9d8f9032 Fix no auto batching bugs: cannot bulk load; not work with namedtuple (#26065)
Summary:
see title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26065

Differential Revision: D17392851

Pulled By: soumith

fbshipit-source-id: 468cd41c8e03d689ff2e0261d948e28daad6bfaf
2019-09-16 07:22:31 -07:00
Tongzhou Wang
058beae411 Add IterableDataset (#19228)
Summary:
This is a modified version of https://github.com/pytorch/pytorch/pull/14705 since commit structure for that PR is quite messy.

1. Add `IterableDataset`.
3. So we have 2 data loader mods: `Iterable` and `Map`.

    1. `Iterable` if the `dataset` is an instance of `IterableDataset`
    2. `Map` o.w.

3. Add better support for non-batch loading (i.e., `batch_size=None` and `batch_sampler=None`). This is useful in doing things like bulk loading.
3. Refactor `DataLoaderIter` into two classes, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`. Rename some methods to be more generic, e.g., `get_batch` -> `get_data`.
4. Add `torch.utils.data.get_worker_info` which returns worker information in a worker proc (e.g., worker id, dataset obj copy, etc.) and can be used in `IterableDataset.__iter__` and `worker_init_fn` to do per-worker configuration.
5. Add `ChainDataset`, which is the analog of `ConcatDataset` for `IterableDataset`.
7. Import torch.utils.data in `torch/__init__.py`
9. data loader examples and documentations
10. Use `get_worker_info` to detect whether we are in a worker process in `default_collate`

Closes https://github.com/pytorch/pytorch/issues/17909, https://github.com/pytorch/pytorch/issues/18096, https://github.com/pytorch/pytorch/issues/19946, and some of https://github.com/pytorch/pytorch/issues/13023
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19228

Reviewed By: bddppq

Differential Revision: D15058152

fbshipit-source-id: 9e081a901a071d7e4502b88054a34b450ab5ddde
2019-06-20 20:12:44 -07:00
Eskil Jörgensen
8042edcdb1 Make pin_memory and default_collate preserve namedtuples (#16440)
Summary:
Open issue: https://github.com/pytorch/pytorch/issues/3281
Corresponding PR (conflict): https://github.com/pytorch/pytorch/pull/4577

Another open issue: https://github.com/pytorch/pytorch/issues/14613
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16440

Differential Revision: D14020901

Pulled By: ezyang

fbshipit-source-id: 4abe817fc43c281a510715d311bad544511995d3
2019-02-11 08:47:33 -08:00
Christoph
2a45050fdc Concatenate directly into shared memory when constructing batches for numpy (#14534)
Summary:
Since #1323 tensors are shared with shared memory, but this feature is not active for numpy.
This PR fix this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14534

Differential Revision: D13561649

Pulled By: soumith

fbshipit-source-id: b6bc9e99fb91e8b675c2ef131fba9fa11c1647c0
2018-12-29 17:51:02 -08:00
SsnL
fb22f76eb6 default_collate should collate bool list to byte tensors (#14669)
Summary:
Based on #15331 . Review only the last commit.

Fixes https://github.com/pytorch/pytorch/issues/14507.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14669

Reviewed By: ezyang

Differential Revision: D13528725

Pulled By: soumith

fbshipit-source-id: f12f1ac1c4ff2a3ddd6877c0c096a5da3a1ffa3c
2018-12-28 12:26:46 -08:00
SsnL
9217bde807 Refactor dataloader.py (#15331)
Summary:
Same as #14668, and was approved there.

ailzhang , please apply this patch to Horizon's `data_streamer.py`: https://gist.github.com/SsnL/020fdb3d6b7016d81b6ba1d04cc41459 Thank you!

Below is the original description at #14668:

As I am working on tasks in https://github.com/pytorch/pytorch/issues/13023, I realized how unreadable the code is because all functions to be run in multiprocessing must be at top global level. Adding more functionalities to `dataloader.py` will only make things worse.

So in this PR, I refactor `dataloader.py` and move much of it into `data._utils`. E.g., the `_worker_loop` and related methods are now in `data._utils.worker`, signal handling code in `data._utils.signal_handling`, collating code in `data._utils.collate`, etc. This split, IMHO, makes code much clearer. I will base my future changes to DataLoader on top of this.

No functionality is changed, except that  I added `torch._six.queue`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15331

Reviewed By: yf225

Differential Revision: D13503120

Pulled By: ailzhang

fbshipit-source-id: 94df16b4d80ad1102c437cde0d5a2e62cffe1f8e
2018-12-19 12:36:03 -08:00
Ailing Zhang
38eb1beff5 Revert D13289919: [pytorch][PR] [DataLoader] Refactor dataloader.py
Differential Revision:
D13289919

Original commit changeset: d701bc7bb48f

fbshipit-source-id: c350c491fefa98a0a7c0cf22cb832e78aeb15c3d
2018-12-04 20:25:16 -08:00
SsnL
16558a1e9d Refactor dataloader.py (#14668)
Summary:
As I am working on tasks in https://github.com/pytorch/pytorch/issues/13023, I realized how unreadable the code is because all functions to be run in multiprocessing must be at top global level. Adding more functionalities to `dataloader.py` will only make things worse.

So in this PR, I refactor `dataloader.py` and move much of it into `data._utils`. E.g., the `_worker_loop` and related methods are now in `data._utils.worker`, signal handling code in `data._utils.signal_handling`, collating code in `data._utils.collate`, etc. This split, IMHO, makes code much clearer. I will base my future changes to DataLoader on top of this.

No functionality is changed, except that  I added `torch._six.queue`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14668

Reviewed By: soumith

Differential Revision: D13289919

Pulled By: ailzhang

fbshipit-source-id: d701bc7bb48f5dd7b163b5be941a9d27eb277a4c
2018-12-04 09:53:41 -08:00