mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Make datasets in ConcatDataset not need to be sized (#64114)
Summary: `datasets` needs to be iterable, but also sized because the length is checked. But immediately after it's converted to a list. By changing the order of these 2 lines, it doesn't need to be sized anymore. Pull Request resolved: https://github.com/pytorch/pytorch/pull/64114 Reviewed By: H-Huang Differential Revision: D30641480 Pulled By: ejguan fbshipit-source-id: 7e16548c2123afa65b83845f9929271fa07fe1e8
This commit is contained in:
parent
535526b95c
commit
69f4401b7b
|
|
@ -271,9 +271,8 @@ class ConcatDataset(Dataset[T_co]):
|
|||
|
||||
def __init__(self, datasets: Iterable[Dataset]) -> None:
|
||||
super(ConcatDataset, self).__init__()
|
||||
# Cannot verify that datasets is Sized
|
||||
assert len(datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type]
|
||||
self.datasets = list(datasets)
|
||||
assert len(self.datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type]
|
||||
for d in self.datasets:
|
||||
assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
|
||||
self.cumulative_sizes = self.cumsum(self.datasets)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user