mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74370 Closes https://github.com/pytorch/data/issues/298. This PR: - removes the `default` parameter of `ShufflerIterDataPipe` - renames `set_shuffle_setting()` to `set_shuffle()` - makes `set_shuffle()` return `self`. Test Plan: Imported from OSS Reviewed By: george-qi Differential Revision: D35073666 Pulled By: NicolasHug fbshipit-source-id: 9847b037e70f44f36eaf4471f2c12fa8ec2ed73c (cherry picked from commit b07ab646f308532886e8daddd57e937a53edb153)
37 lines
1.4 KiB
Python
37 lines
1.4 KiB
Python
import torch.utils.data.graph
|
|
from torch.utils.data.datapipes.iter import Shuffler
|
|
|
|
|
|
def get_all_graph_pipes(graph):
    """Flatten a nested graph mapping into the set of all its keys.

    ``graph`` is a mapping whose values are themselves mappings of the same
    shape. Every key found at any nesting depth ends up in the result.

    Args:
        graph: Mapping of ``node -> sub-graph``; an empty mapping marks a leaf.

    Returns:
        A ``set`` holding each key encountered while walking ``graph``.
    """
    collected = set()
    for node, children in graph.items():
        collected.add(node)
        # Recurse into the sub-graph and merge everything found below this node.
        collected.update(get_all_graph_pipes(children))
    return collected
|
|
|
|
|
|
def apply_sharding(datapipe, num_of_instances, instance_id):
    """Apply sharding to the single shardable pipe in ``datapipe``'s graph.

    The graph is built with ``torch.utils.data.graph.traverse`` and flattened
    with ``get_all_graph_pipes``. A pipe qualifies when it has an
    ``is_shardable`` method that returns a truthy value and also has an
    ``apply_sharding`` method.

    Args:
        datapipe: Root object passed to ``traverse``.
        num_of_instances: Forwarded as the first argument of the pipe's
            ``apply_sharding``.
        instance_id: Forwarded as the second argument of the pipe's
            ``apply_sharding``.

    Raises:
        RuntimeError: If a second qualifying pipe is found after sharding has
            already been applied to one.
    """
    graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True)
    already_applied_to = None
    for pipe in get_all_graph_pipes(graph):
        # Guard clauses keep the original evaluation order:
        # hasattr(is_shardable) -> is_shardable() -> hasattr(apply_sharding).
        if not hasattr(pipe, 'is_shardable') or not pipe.is_shardable():
            continue
        if not hasattr(pipe, 'apply_sharding'):
            continue
        if already_applied_to is not None:
            # Build one message string; passing several positional arguments
            # to RuntimeError makes str(exc) render as a tuple.
            raise RuntimeError(
                'This implementation of sharding can be only applied once per instance of DataPipeline. '
                'Already applied to {!r} while trying to apply to {!r}'.format(already_applied_to, pipe)
            )
        pipe.apply_sharding(num_of_instances, instance_id)
        already_applied_to = pipe
|
|
|
|
|
|
def apply_shuffle_settings(datapipe, shuffle):
    """Forward ``shuffle`` to every ``Shuffler`` in ``datapipe``'s graph.

    Does nothing when ``shuffle`` is ``None``. Otherwise the graph is built
    with ``torch.utils.data.graph.traverse``, flattened with
    ``get_all_graph_pipes``, and ``set_shuffle(shuffle)`` is called on each
    pipe that is a ``Shuffler`` instance.

    Args:
        datapipe: Root object passed to ``traverse``.
        shuffle: Value handed to ``set_shuffle``; ``None`` leaves all pipes
            untouched.
    """
    if shuffle is None:
        return

    pipe_graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True)
    for candidate in get_all_graph_pipes(pipe_graph):
        if not isinstance(candidate, Shuffler):
            continue
        candidate.set_shuffle(shuffle)
|