Remove default parameter of ShufflerIterDataPipe (#74370)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74370

Closes https://github.com/pytorch/data/issues/298. This PR:

- removes the `default` parameter of `ShufflerIterDataPipe`
- renames `set_shuffle_setting()` into `set_shuffle()`
- let `set_shuffle()` return `self`.

Test Plan: Imported from OSS

Reviewed By: george-qi

Differential Revision: D35073666

Pulled By: NicolasHug

fbshipit-source-id: 9847b037e70f44f36eaf4471f2c12fa8ec2ed73c
(cherry picked from commit b07ab646f308532886e8daddd57e937a53edb153)
This commit is contained in:
Nicolas Hug 2022-03-28 05:42:22 -07:00 committed by PyTorch MergeBot
parent 1c5a812579
commit 5667c4ea21
3 changed files with 12 additions and 7 deletions

View File

@ -1407,6 +1407,10 @@ class TestFunctionalIterDataPipe(TestCase):
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
len(shuffle_dp_nl)
# Test: deactivate shuffling via set_shuffle
unshuffled_dp = input_ds.shuffle().set_shuffle(False)
self.assertEqual(list(unshuffled_dp), list(input_ds))
def test_zip_iterdatapipe(self):
# Functional Test: raises TypeError when an input is not of type `IterDataPipe`

View File

@ -84,7 +84,6 @@ class ShufflerIterDataPipe(IterDataPipe[T_co]):
def __init__(self,
datapipe: IterDataPipe[T_co],
*,
default: bool = True,
buffer_size: int = 10000,
unbatch_level: int = 0
) -> None:
@ -95,7 +94,7 @@ class ShufflerIterDataPipe(IterDataPipe[T_co]):
else:
self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level)
self.buffer_size = buffer_size
self._shuffle_enabled = default
self._enabled = True
@staticmethod
def buffer_replace(buffer, x):
@ -104,11 +103,12 @@ class ShufflerIterDataPipe(IterDataPipe[T_co]):
buffer[idx] = x
return val
def set_shuffle_settings(self, shuffle=True):
self._shuffle_enabled = shuffle
def set_shuffle(self, shuffle=True):
self._enabled = shuffle
return self
def __iter__(self) -> Iterator[T_co]:
if not self._shuffle_enabled:
if not self._enabled:
for x in self.datapipe:
yield x
else:

View File

@ -1,4 +1,5 @@
import torch.utils.data.graph
from torch.utils.data.datapipes.iter import Shuffler
def get_all_graph_pipes(graph):
@ -31,5 +32,5 @@ def apply_shuffle_settings(datapipe, shuffle):
graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True)
all_pipes = get_all_graph_pipes(graph)
for pipe in all_pipes:
if hasattr(pipe, 'set_shuffle_settings'):
pipe.set_shuffle_settings(shuffle)
if isinstance(pipe, Shuffler):
pipe.set_shuffle(shuffle)