mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
The [fastNLP](https://github.com/fastnlp/fastNLP/blob/v0.6.0/fastNLP/core/batch.py#L51) model uses DataSetGetter to fetch data from the dataset. The following code breaks because of https://github.com/pytorch/pytorch/pull/84301: ``` from fastNLP.io.pipe.qa import CMRC2018BertPipe input_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".data", "cmrc2018-sim") data_bundle = CMRC2018BertPipe().process_from_file(paths=input_dir) data_bundle.rename_field('chars', 'words') data_bundle.get_dataset('dev') dataset = DataSetGetter(dataset, as_numpy) dataiter = torch.utils.data.DataLoader(dataset=dataset) for batch in dataiter: # data-processing... ``` This is because for the `DataSetGetter` class, the following condition holds: ``` # hasattr(dataset_getter, '__getitems__') == True # dataset_getter.__getitems__ == None ``` This PR adds an additional check to make sure `__getitems__` is only called when it is not None. This error was found by the torchbench nightly CI, original error stack trace: ``` ERROR: test_fastNLP_Bert_train_cuda (__main__.TestBenchmark) ---------------------------------------------------------------------- components._impl.workers.subprocess_rpc.ChildTraceException: Traceback (most recent call last): File "/home/circleci/project/components/_impl/workers/subprocess_rpc.py", line 470, in _run_block exec( # noqa: P204 File "<subprocess-worker>", line 35, in <module> File "<subprocess-worker>", line 12, in _run_in_worker_f File "/home/circleci/project/torchbenchmark/util/model.py", line 16, in __call__ obj = type.__call__(cls, *args, **kwargs) File "/home/circleci/project/torchbenchmark/models/fastNLP_Bert/__init__.py", line 93, in __init__ self.example_inputs = self._prefetch(example_inputs) File "/home/circleci/project/torchbenchmark/models/fastNLP_Bert/__init__.py", line 133, in _prefetch for batch_x, batch_y in example_inputs: File "/home/circleci/miniconda3/lib/python3.8/site-packages/fastNLP/core/batch.py", line 266, in __iter__ for 
indices, batch_x, batch_y in self.dataiter: File "/home/circleci/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 681, in __next__ data = self._next_data() File "/home/circleci/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 719, in _next_data data = self._dataset_fetcher.fetch(index) # may raise StopIteration File "/home/circleci/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 56, in fetch data = self.dataset.__getitems__(possibly_batched_index) TypeError: 'NoneType' object is not callable ``` Full error log: https://app.circleci.com/pipelines/github/pytorch/benchmark/5143/workflows/0676f36d-0ab4-42bd-adb4-90e6b0df76d1/jobs/5293 Pull Request resolved: https://github.com/pytorch/pytorch/pull/85099 Approved by: https://github.com/ejguan
62 lines
2.1 KiB
Python
62 lines
2.1 KiB
Python
r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
data from an iterable-style or map-style dataset. This logic is shared in both
single- and multi-processing data loading.
"""
|
|
|
|
|
|
class _BaseDatasetFetcher(object):
|
|
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
|
|
self.dataset = dataset
|
|
self.auto_collation = auto_collation
|
|
self.collate_fn = collate_fn
|
|
self.drop_last = drop_last
|
|
|
|
def fetch(self, possibly_batched_index):
|
|
raise NotImplementedError()
|
|
|
|
|
|
class _IterableDatasetFetcher(_BaseDatasetFetcher):
|
|
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
|
|
super(_IterableDatasetFetcher, self).__init__(
|
|
dataset, auto_collation, collate_fn, drop_last
|
|
)
|
|
self.dataset_iter = iter(dataset)
|
|
self.ended = False
|
|
|
|
def fetch(self, possibly_batched_index):
|
|
if self.ended:
|
|
raise StopIteration
|
|
|
|
if self.auto_collation:
|
|
data = []
|
|
for _ in possibly_batched_index:
|
|
try:
|
|
data.append(next(self.dataset_iter))
|
|
except StopIteration:
|
|
self.ended = True
|
|
break
|
|
if len(data) == 0 or (
|
|
self.drop_last and len(data) < len(possibly_batched_index)
|
|
):
|
|
raise StopIteration
|
|
else:
|
|
data = next(self.dataset_iter)
|
|
return self.collate_fn(data)
|
|
|
|
|
|
class _MapDatasetFetcher(_BaseDatasetFetcher):
|
|
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
|
|
super(_MapDatasetFetcher, self).__init__(
|
|
dataset, auto_collation, collate_fn, drop_last
|
|
)
|
|
|
|
def fetch(self, possibly_batched_index):
|
|
if self.auto_collation:
|
|
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
|
|
data = self.dataset.__getitems__(possibly_batched_index)
|
|
else:
|
|
data = [self.dataset[idx] for idx in possibly_batched_index]
|
|
else:
|
|
data = self.dataset[possibly_batched_index]
|
|
return self.collate_fn(data)
|