mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Remove default parameter of ShufflerIterDataPipe (#74370)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74370 Closes https://github.com/pytorch/data/issues/298. This PR: - removes the `default` parameter of `ShufflerIterDataPipe` - renames `set_shuffle_setting()` into `set_shuffle()` - let `set_shuffle()` return `self`. Test Plan: Imported from OSS Reviewed By: george-qi Differential Revision: D35073666 Pulled By: NicolasHug fbshipit-source-id: 9847b037e70f44f36eaf4471f2c12fa8ec2ed73c (cherry picked from commit b07ab646f308532886e8daddd57e937a53edb153)
This commit is contained in:
parent
1c5a812579
commit
5667c4ea21
|
|
@ -1407,6 +1407,10 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
|
||||
len(shuffle_dp_nl)
|
||||
|
||||
# Test: deactivate shuffling via set_shuffle
|
||||
unshuffled_dp = input_ds.shuffle().set_shuffle(False)
|
||||
self.assertEqual(list(unshuffled_dp), list(input_ds))
|
||||
|
||||
def test_zip_iterdatapipe(self):
|
||||
|
||||
# Functional Test: raises TypeError when an input is not of type `IterDataPipe`
|
||||
|
|
|
|||
|
|
@ -84,7 +84,6 @@ class ShufflerIterDataPipe(IterDataPipe[T_co]):
|
|||
def __init__(self,
|
||||
datapipe: IterDataPipe[T_co],
|
||||
*,
|
||||
default: bool = True,
|
||||
buffer_size: int = 10000,
|
||||
unbatch_level: int = 0
|
||||
) -> None:
|
||||
|
|
@ -95,7 +94,7 @@ class ShufflerIterDataPipe(IterDataPipe[T_co]):
|
|||
else:
|
||||
self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level)
|
||||
self.buffer_size = buffer_size
|
||||
self._shuffle_enabled = default
|
||||
self._enabled = True
|
||||
|
||||
@staticmethod
|
||||
def buffer_replace(buffer, x):
|
||||
|
|
@ -104,11 +103,12 @@ class ShufflerIterDataPipe(IterDataPipe[T_co]):
|
|||
buffer[idx] = x
|
||||
return val
|
||||
|
||||
def set_shuffle_settings(self, shuffle=True):
|
||||
self._shuffle_enabled = shuffle
|
||||
def set_shuffle(self, shuffle=True):
|
||||
self._enabled = shuffle
|
||||
return self
|
||||
|
||||
def __iter__(self) -> Iterator[T_co]:
|
||||
if not self._shuffle_enabled:
|
||||
if not self._enabled:
|
||||
for x in self.datapipe:
|
||||
yield x
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import torch.utils.data.graph
|
||||
from torch.utils.data.datapipes.iter import Shuffler
|
||||
|
||||
|
||||
def get_all_graph_pipes(graph):
|
||||
|
|
@ -31,5 +32,5 @@ def apply_shuffle_settings(datapipe, shuffle):
|
|||
graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True)
|
||||
all_pipes = get_all_graph_pipes(graph)
|
||||
for pipe in all_pipes:
|
||||
if hasattr(pipe, 'set_shuffle_settings'):
|
||||
pipe.set_shuffle_settings(shuffle)
|
||||
if isinstance(pipe, Shuffler):
|
||||
pipe.set_shuffle(shuffle)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user