Make datasets in ConcatDataset not need to be sized (#64114)

Summary:
`datasets` only needs to be iterable, but it also had to be sized because `len(datasets)` was checked before the argument was converted to a list on the very next line. Swapping these two lines, so that the length check runs on the resulting list, means `datasets` no longer needs to be sized.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64114

Reviewed By: H-Huang

Differential Revision: D30641480

Pulled By: ejguan

fbshipit-source-id: 7e16548c2123afa65b83845f9929271fa07fe1e8
Santiago Castro 2021-09-01 15:18:14 -07:00 committed by Facebook GitHub Bot
parent 535526b95c
commit 69f4401b7b

@@ -271,9 +271,8 @@ class ConcatDataset(Dataset[T_co]):
     def __init__(self, datasets: Iterable[Dataset]) -> None:
         super(ConcatDataset, self).__init__()
-        # Cannot verify that datasets is Sized
-        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
         self.datasets = list(datasets)
+        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
         for d in self.datasets:
             assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
         self.cumulative_sizes = self.cumsum(self.datasets)
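
The effect of the reordering can be sketched without PyTorch. The snippet below is an illustration only, not code from this commit: the function names `check_then_materialize` and `materialize_then_check` are hypothetical, and plain lists stand in for map-style datasets. It passes a generator (iterable but not sized) through both orderings of the two lines.

```python
from typing import Iterable, List


def check_then_materialize(datasets: Iterable[object]) -> List[object]:
    """Old order (hypothetical helper): len() runs on the raw argument, so it must be sized."""
    assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
    return list(datasets)


def materialize_then_check(datasets: Iterable[object]) -> List[object]:
    """New order (hypothetical helper): build the list first, then check the list's length."""
    materialized = list(datasets)
    assert len(materialized) > 0, 'datasets should not be an empty iterable'
    return materialized


def make_parts():
    # A generator is iterable but has no __len__; the inner lists stand in for datasets.
    return (list(range(n)) for n in (2, 3))


# Old order: calling len() on a generator raises TypeError before anything else happens.
try:
    check_then_materialize(make_parts())
    old_order_error = None
except TypeError as exc:
    old_order_error = exc

# New order: the same generator is accepted.
new_order_result = materialize_then_check(make_parts())

# The emptiness check still fires for an empty iterable.
try:
    materialize_then_check(iter(()))
    empty_error = None
except AssertionError as exc:
    empty_error = exc
```

Under these assumptions, the old ordering rejects a generator with a `TypeError`, while the new ordering accepts it and still rejects an empty iterable with the same assertion message.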