Fix Subset of a Subset not sliceable issue (#59513)

Summary:
Dataset can be indexed by a list, but a list can not be indexed by a list. This gives error when slicing a Subset initialised with a Subset, instead of a dataset.

Fixed the issue by changing the indices to a Tensor which can be indexed by a list.

Fixes https://github.com/pytorch/pytorch/issues/59512

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59513

Reviewed By: zou3519

Differential Revision: D29196891

Pulled By: ejguan

fbshipit-source-id: ccde6e474fbcbddd2e9c7c107bc8b5de1307cdb9
This commit is contained in:
TJ-coding 2021-06-18 07:06:17 -07:00 committed by Facebook GitHub Bot
parent 08ce5eedf5
commit 7c29ca7f2b
2 changed files with 32 additions and 1 deletions

View File

@ -13,7 +13,7 @@ import itertools
import warnings
import tempfile
from torch import multiprocessing as mp
from torch.utils.data import _utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, ChainDataset
from torch.utils.data import _utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, ChainDataset, Subset
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataset import random_split
from torch._utils import ExceptionWrapper
@ -151,6 +151,35 @@ class TestDatasetRandomSplit(TestCase):
b = torch.rand(10)
self.assertEqual(a, b)
def test_slicing_of_subset_of_dataset(self):
# Testing slicing a subset initialized with a dataset
dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4])
self.assertEqual(subset_of_dataset[:], dataset[:])
self.assertEqual(subset_of_dataset[1:2], dataset[1:2])
self.assertEqual(subset_of_dataset[0:-1:2], dataset[0:-1:2])
# Testing slicing of subset from random split
subset1, subset2 = random_split(dataset, [3, 2])
self.assertEqual(subset1[:], dataset[subset1.indices[:]])
self.assertEqual(subset1[0:2], dataset[subset1.indices[0:2]])
self.assertEqual(subset1[0:-1:2], dataset[subset1.indices[0:-1:2]])
def test_slicing_of_subset_of_subset(self):
# Testing slicing a subset initialized with a subset
dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5]))
subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4])
subset_of_subset = Subset(subset_of_dataset, [0, 1, 2, 3, 4])
self.assertEqual(subset_of_subset[:], dataset[:])
self.assertEqual(subset_of_subset[0:2], dataset[0:2])
self.assertEqual(subset_of_subset[0:-1:2], dataset[0:-1:2])
# Testing slicing of subset of subset from random split
subset1, subset2 = random_split(dataset, [4, 1])
subset_of_subset1, subset_of_subset2 = random_split(subset1, [3, 1])
idx = [subset1.indices[i] for i in subset_of_subset1.indices]
self.assertEqual(subset_of_subset1[:], dataset[idx[:]])
self.assertEqual(subset_of_subset1[0:2], dataset[idx[0:2]])
self.assertEqual(subset_of_subset1[0:-1:2], dataset[idx[0:-1:2]])
class CUDACountingDataset(Dataset):
def __init__(self, n):

View File

@ -316,6 +316,8 @@ class Subset(Dataset[T_co]):
self.indices = indices
def __getitem__(self, idx):
if isinstance(idx, list):
return self.dataset[[self.indices[i] for i in idx]]
return self.dataset[self.indices[idx]]
def __len__(self):