pytorch/torch/utils/data/_utils/fetch.py
Xu Zhao 52a2b61203 Fix fetch function which breaks user code (#85099)
The [fastNLP](https://github.com/fastnlp/fastNLP/blob/v0.6.0/fastNLP/core/batch.py#L51) model uses `DataSetGetter` to fetch data from its dataset. The following code breaks because of https://github.com/pytorch/pytorch/pull/84301:

```
import os

import torch
from fastNLP.core.batch import DataSetGetter
from fastNLP.io.pipe.qa import CMRC2018BertPipe

input_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".data", "cmrc2018-sim")
data_bundle = CMRC2018BertPipe().process_from_file(paths=input_dir)
data_bundle.rename_field('chars', 'words')
dataset = data_bundle.get_dataset('dev')
as_numpy = False
dataset = DataSetGetter(dataset, as_numpy)
dataiter = torch.utils.data.DataLoader(dataset=dataset)
for batch in dataiter:
    pass  # data-processing...
```

This is because, for the `DataSetGetter` class, the following conditions hold:
```
# hasattr(dataset_getter, '__getitems__') == True
# dataset_getter.__getitems__ == None
```

This PR adds a check to make sure `__getitems__` is only called when it is not `None`.
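
For illustration, here is a minimal sketch of the failure mode; `BatchDataset` below is a hypothetical stand-in for `DataSetGetter`, not fastNLP's actual implementation:

```
class BatchDataset:
    def __init__(self, samples):
        self.samples = samples
        self.__getitems__ = None  # attribute exists, but its value is None

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)


ds = BatchDataset([1, 2, 3])
print(hasattr(ds, "__getitems__"))  # True, even though the value is None

# Old check (hasattr only): calling ds.__getitems__([0, 1]) raises
# TypeError: 'NoneType' object is not callable.
# New check: only call __getitems__ when it is truthy.
if hasattr(ds, "__getitems__") and ds.__getitems__:
    data = ds.__getitems__([0, 1])
else:
    data = [ds[idx] for idx in [0, 1]]
print(data)  # [1, 2]
```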

This error was found by the torchbench nightly CI. Original error stack trace:
```
ERROR: test_fastNLP_Bert_train_cuda (__main__.TestBenchmark)
----------------------------------------------------------------------
components._impl.workers.subprocess_rpc.ChildTraceException: Traceback (most recent call last):
  File "/home/circleci/project/components/_impl/workers/subprocess_rpc.py", line 470, in _run_block
    exec(  # noqa: P204
  File "<subprocess-worker>", line 35, in <module>
  File "<subprocess-worker>", line 12, in _run_in_worker_f
  File "/home/circleci/project/torchbenchmark/util/model.py", line 16, in __call__
    obj = type.__call__(cls, *args, **kwargs)
  File "/home/circleci/project/torchbenchmark/models/fastNLP_Bert/__init__.py", line 93, in __init__
    self.example_inputs = self._prefetch(example_inputs)
  File "/home/circleci/project/torchbenchmark/models/fastNLP_Bert/__init__.py", line 133, in _prefetch
    for batch_x, batch_y in example_inputs:
  File "/home/circleci/miniconda3/lib/python3.8/site-packages/fastNLP/core/batch.py", line 266, in __iter__
    for indices, batch_x, batch_y in self.dataiter:
  File "/home/circleci/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/home/circleci/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 719, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/circleci/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 56, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
TypeError: 'NoneType' object is not callable
```

Full error log: https://app.circleci.com/pipelines/github/pytorch/benchmark/5143/workflows/0676f36d-0ab4-42bd-adb4-90e6b0df76d1/jobs/5293
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85099
Approved by: https://github.com/ejguan
2022-09-15 21:48:28 +00:00

62 lines
2.1 KiB
Python

r""""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
data from an iterable-style or map-style dataset. This logic is shared in both
single- and multi-processing data loading.
"""

class _BaseDatasetFetcher(object):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        self.dataset = dataset
        self.auto_collation = auto_collation
        self.collate_fn = collate_fn
        self.drop_last = drop_last

    def fetch(self, possibly_batched_index):
        raise NotImplementedError()


class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(
            dataset, auto_collation, collate_fn, drop_last
        )
        self.dataset_iter = iter(dataset)
        self.ended = False

    def fetch(self, possibly_batched_index):
        if self.ended:
            raise StopIteration

        if self.auto_collation:
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    self.ended = True
                    break
            if len(data) == 0 or (
                self.drop_last and len(data) < len(possibly_batched_index)
            ):
                raise StopIteration
        else:
            data = next(self.dataset_iter)
        return self.collate_fn(data)

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(
            dataset, auto_collation, collate_fn, drop_last
        )

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # Only take the batched __getitems__ path if the attribute exists
            # *and* is not None (some datasets, e.g. fastNLP's DataSetGetter,
            # expose the attribute but set it to None).
            if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                data = self.dataset.__getitems__(possibly_batched_index)
            else:
                data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)