mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix Subset of a Subset not sliceable issue (#59513)
Summary: A Dataset can be indexed by a list, but a Python list cannot be indexed by a list. This raises an error when slicing a Subset that was initialised with another Subset instead of a dataset. Fixed the issue by changing the indices to a Tensor, which can be indexed by a list. (In the diff shown on this page, `Subset.__getitem__` handles a list index by mapping each element through `self.indices`.) Fixes https://github.com/pytorch/pytorch/issues/59512 Pull Request resolved: https://github.com/pytorch/pytorch/pull/59513 Reviewed By: zou3519 Differential Revision: D29196891 Pulled By: ejguan fbshipit-source-id: ccde6e474fbcbddd2e9c7c107bc8b5de1307cdb9
This commit is contained in:
parent
08ce5eedf5
commit
7c29ca7f2b
|
|
@ -13,7 +13,7 @@ import itertools
|
|||
import warnings
|
||||
import tempfile
|
||||
from torch import multiprocessing as mp
|
||||
from torch.utils.data import _utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, ChainDataset
|
||||
from torch.utils.data import _utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, ChainDataset, Subset
|
||||
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
|
||||
from torch.utils.data.dataset import random_split
|
||||
from torch._utils import ExceptionWrapper
|
||||
|
|
@ -151,6 +151,35 @@ class TestDatasetRandomSplit(TestCase):
|
|||
b = torch.rand(10)
|
||||
self.assertEqual(a, b)
|
||||
|
||||
def test_slicing_of_subset_of_dataset(self):
    """Slicing a Subset built directly on a dataset matches slicing the dataset.

    Covers a Subset constructed with explicit identity indices and a Subset
    produced by ``random_split``; in the latter case the expected value is
    the dataset indexed by the matching slice of the subset's ``indices``.
    """
    base = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))

    # Identity indices: the subset and the dataset must agree slice by slice.
    identity_view = Subset(base, [0, 1, 2, 3, 4])
    for window in (slice(None), slice(1, 2), slice(0, -1, 2)):
        self.assertEqual(identity_view[window], base[window])

    # Subset coming out of random_split: compare through its own indices.
    first_part, _ = random_split(base, [3, 2])
    for window in (slice(None), slice(0, 2), slice(0, -1, 2)):
        self.assertEqual(first_part[window], base[first_part.indices[window]])
|
||||
|
||||
def test_slicing_of_subset_of_subset(self):
    """Slicing a Subset that wraps another Subset matches slicing the dataset.

    Covers nested Subsets built with explicit identity indices and nested
    Subsets produced by two successive ``random_split`` calls; in the latter
    case the inner indices are first translated back to dataset positions.
    """
    base = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))

    # Two layers of identity indices: the nesting must be transparent.
    inner_view = Subset(base, [0, 1, 2, 3, 4])
    nested_view = Subset(inner_view, [0, 1, 2, 3, 4])
    for window in (slice(None), slice(0, 2), slice(0, -1, 2)):
        self.assertEqual(nested_view[window], base[window])

    # Nested random splits: resolve the inner subset's indices through the
    # outer subset's indices to get positions in the original dataset.
    outer_part, _ = random_split(base, [4, 1])
    nested_part, _ = random_split(outer_part, [3, 1])
    resolved = [outer_part.indices[pos] for pos in nested_part.indices]
    for window in (slice(None), slice(0, 2), slice(0, -1, 2)):
        self.assertEqual(nested_part[window], base[resolved[window]])
|
||||
|
||||
|
||||
class CUDACountingDataset(Dataset):
|
||||
def __init__(self, n):
|
||||
|
|
|
|||
|
|
@ -316,6 +316,8 @@ class Subset(Dataset[T_co]):
|
|||
self.indices = indices
|
||||
|
||||
def __getitem__(self, idx):
    """Look up ``idx`` in the wrapped dataset via this subset's indices.

    A ``list`` index is translated element by element through
    ``self.indices`` and the resulting list is passed to the dataset in one
    call. Any other index is used to index ``self.indices`` directly, and
    the result is passed to the dataset.
    """
    if not isinstance(idx, list):
        return self.dataset[self.indices[idx]]
    # self.indices may be a plain list, which cannot itself be indexed by a
    # list, so map each requested position individually.
    translated = [self.indices[position] for position in idx]
    return self.dataset[translated]
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user