pytorch/torch/utils/data/graph_settings.py
Nicolas Hug 5667c4ea21 Remove default parameter of ShufflerIterDataPipe (#74370)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74370

Closes https://github.com/pytorch/data/issues/298. This PR:

- removes the `default` parameter of `ShufflerIterDataPipe`
- renames `set_shuffle_setting()` into `set_shuffle()`
- makes `set_shuffle()` return `self`.

Test Plan: Imported from OSS

Reviewed By: george-qi

Differential Revision: D35073666

Pulled By: NicolasHug

fbshipit-source-id: 9847b037e70f44f36eaf4471f2c12fa8ec2ed73c
(cherry picked from commit b07ab646f308532886e8daddd57e937a53edb153)
2022-03-28 12:47:24 +00:00

37 lines
1.4 KiB
Python

import torch.utils.data.graph
from torch.utils.data.datapipes.iter import Shuffler
def get_all_graph_pipes(graph):
    """Collect every datapipe that appears anywhere in a traversal graph.

    Args:
        graph: A mapping whose keys are datapipes and whose values are
            sub-graphs of the same shape. An empty mapping ends the
            recursion.

    Returns:
        A ``set`` holding every key found at any nesting depth. A datapipe
        reachable through several paths appears only once.
    """
    results = set()
    for datapipe, sub_graph in graph.items():
        results.add(datapipe)
        # Merge the whole sub-graph in one call instead of item by item.
        results.update(get_all_graph_pipes(sub_graph))
    return results
def apply_sharding(datapipe, num_of_instances, instance_id):
    """Apply sharding to the one shardable datapipe in a pipeline.

    The graph rooted at ``datapipe`` is traversed and every datapipe in it is
    inspected. A datapipe is sharded when it exposes ``is_shardable``, that
    method returns a truthy value, and it also exposes ``apply_sharding``.

    Args:
        datapipe: Root of the pipeline to inspect.
        num_of_instances: Forwarded unchanged to the pipe's ``apply_sharding``.
        instance_id: Forwarded unchanged to the pipe's ``apply_sharding``.

    Raises:
        RuntimeError: If a second qualifying datapipe is found. By then
            sharding has already been applied to the first one.
    """
    graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True)
    all_pipes = get_all_graph_pipes(graph)
    already_applied_to = None
    # ``all_pipes`` is a set, so the visiting order is not fixed by this code.
    for pipe in all_pipes:
        # Same check order as before: probe for ``is_shardable``, call it,
        # then probe for ``apply_sharding``.
        qualifies = (
            hasattr(pipe, 'is_shardable')
            and pipe.is_shardable()
            and hasattr(pipe, 'apply_sharding')
        )
        if not qualifies:
            continue
        if already_applied_to is not None:
            # One formatted string, so str(exc) is a readable sentence
            # rather than the repr of a tuple of positional arguments.
            raise RuntimeError(
                'This implementation of sharding can be only applied once per instance of DataPipeline. '
                'Already applied to {} while trying to apply to {}'.format(already_applied_to, pipe)
            )
        pipe.apply_sharding(num_of_instances, instance_id)
        already_applied_to = pipe
def apply_shuffle_settings(datapipe, shuffle):
    """Propagate a shuffle setting to every ``Shuffler`` in a pipeline.

    Args:
        datapipe: Root of the pipeline whose graph is traversed.
        shuffle: Value handed to ``set_shuffle`` on each ``Shuffler`` found.
            When it is ``None`` the pipeline is not traversed or modified.
    """
    if shuffle is None:
        return

    pipeline_graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True)
    # Lazy filter: each node is type-checked just before it is updated.
    shufflers = (
        node for node in get_all_graph_pipes(pipeline_graph)
        if isinstance(node, Shuffler)
    )
    for shuffler in shufflers:
        shuffler.set_shuffle(shuffle)