mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Revert "Add support to traverse all python collection objects (#84079)"
This reverts commit e0f0c8e7b9.
Reverted https://github.com/pytorch/pytorch/pull/84079 on behalf of https://github.com/weiwangmeta due to Diff reverted internally
This commit is contained in:
parent
0ac2986d33
commit
d50aa517b5
|
|
@ -1,12 +1,11 @@
|
|||
import io
|
||||
import pickle
|
||||
|
||||
from collections.abc import Collection
|
||||
from typing import Dict, List, Set, Tuple, Type, Union
|
||||
|
||||
from torch.utils.data import IterDataPipe, MapDataPipe
|
||||
from torch.utils.data._utils.serialization import DILL_AVAILABLE
|
||||
|
||||
from typing import Dict, List, Set, Tuple, Type, Union
|
||||
|
||||
__all__ = ["traverse"]
|
||||
|
||||
DataPipe = Union[IterDataPipe, MapDataPipe]
|
||||
|
|
@ -37,7 +36,7 @@ def _list_connected_datapipes(scan_obj: DataPipe, only_datapipe: bool, cache: Se
|
|||
def getstate_hook(obj):
|
||||
state = {}
|
||||
for k, v in obj.__dict__.items():
|
||||
if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
|
||||
if isinstance(v, (IterDataPipe, MapDataPipe, tuple)):
|
||||
state[k] = v
|
||||
return state
|
||||
|
||||
|
|
@ -75,19 +74,6 @@ def _list_connected_datapipes(scan_obj: DataPipe, only_datapipe: bool, cache: Se
|
|||
|
||||
|
||||
def traverse(datapipe: DataPipe, only_datapipe: bool = False) -> DataPipeGraph:
|
||||
r"""
|
||||
Traverse the DataPipes and their attributes to extract the DataPipe graph. When
|
||||
``only_dataPipe`` is specified as ``True``, it would only look into the attribute
|
||||
from each DataPipe that is either a DataPipe and a Python collection object such as
|
||||
``list``, ``tuple``, ``set`` and ``dict``.
|
||||
|
||||
Args:
|
||||
datapipe: the end DataPipe of the graph
|
||||
only_datapipe: If ``False`` (default), all attributes of each DataPipe are traversed
|
||||
Returns:
|
||||
A graph represented as a nested dictionary, where keys are ids of DataPipe instances
|
||||
and values are tuples of DataPipe instance and the sub-graph
|
||||
"""
|
||||
cache: Set[int] = set()
|
||||
return _traverse_helper(datapipe, only_datapipe, cache)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user