pytorch/torch/utils/data/graph.py
Kevin Tse 64a526d4af [DataLoader] Replacing traverse function with traverse_datapipes (#85667)
This PR deprecates the `traverse` function and replaces it with `traverse_datapipes` (defined in this file under the name `traverse_dps`).

While using `DataLoader`, I realized that it raises a `FutureWarning` even though I am not explicitly calling `traverse`. What happens is that `DataLoader` invokes `traverse(dp, only_datapipe=True)`, and passing that keyword causes the `only_datapipe` warning to be raised.

```
/home/ubuntu/miniconda3/lib/python3.8/site-packages/torch/utils/data/graph.py:102: FutureWarning: `only_datapipe` is deprecated from `traverse` function and will be removed after 1.13.
  warnings.warn(msg, FutureWarning)
```

A few things we'd like to do:
1. Deprecate the keyword argument `only_datapipe`
2. Change the default behavior from `only_datapipe=False` to `only_datapipe=True` in the future
3. Do not raise a warning when users are using the function correctly

This creates a paradox: it is impossible for users to change their code to match the future default behavior (i.e. call `traverse(dp)` without `only_datapipe`):
  - they cannot do so because the default behavior of `traverse` hasn't changed yet, so they must use `only_datapipe=True`
  - if they use `only_datapipe=True`, eventually the kwarg will go away and cause a runtime error; they also get a `FutureWarning` in the present

IIUC, there doesn't seem to be a way to accomplish those 3 goals without replacing the function with a new one that has a different name; hence, this PR. Let me know if there is a better alternative.

If this looks right, I will send a follow up PR in `TorchData`.

Differential Revision: [D39832183](https://our.internmc.facebook.com/intern/diff/D39832183)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85667
Approved by: https://github.com/ejguan
2022-09-27 19:58:15 +00:00

147 lines
5.7 KiB
Python

import io
import pickle
import warnings
from collections.abc import Collection
from typing import Dict, List, Optional, Set, Tuple, Type, Union
from torch.utils.data import IterDataPipe, MapDataPipe
from torch.utils.data._utils.serialization import DILL_AVAILABLE
# Public API of this module: the deprecated `traverse` and its replacement `traverse_dps`.
__all__ = ["traverse", "traverse_dps"]
# Either flavor of DataPipe; this is the node type stored in a DataPipeGraph.
DataPipe = Union[IterDataPipe, MapDataPipe]
# Recursive graph type: maps id(datapipe) -> (datapipe, sub-graph of the DataPipes connected to it).
DataPipeGraph = Dict[int, Tuple[DataPipe, "DataPipeGraph"]] # type: ignore[misc]
def _stub_unpickler():
return "STUB"
# TODO(VitalyFedyunin): Make sure it works without dill module installed
def _list_connected_datapipes(scan_obj: DataPipe, only_datapipe: bool, cache: Set[int]) -> List[DataPipe]:
    r"""
    List the DataPipes that are referenced by ``scan_obj``.

    ``scan_obj`` is pickled into a throw-away buffer while a reduce hook is installed on
    ``IterDataPipe`` and ``MapDataPipe``. The hook records every DataPipe handed to it and
    returns a stub reduction for it; the pickled bytes themselves are never read back.

    Args:
        scan_obj: the DataPipe whose connections are collected
        only_datapipe: if ``True``, a getstate hook is installed as well, which narrows
            the pickled state to values that are DataPipes or ``Collection`` instances
        cache: ids of DataPipes that must not be captured; the id of every captured
            DataPipe is added to this set in place

    Returns:
        The captured DataPipes, in the order the reduce hook received them.
    """
    buffer = io.BytesIO()
    # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is
    std_pickler = pickle.Pickler(buffer)
    if DILL_AVAILABLE:
        from dill import Pickler as dill_Pickler

        dill_pickler = dill_Pickler(buffer)
    else:
        dill_pickler = None

    found: List[DataPipe] = []
    kept_types = (IterDataPipe, MapDataPipe, Collection)

    def getstate_hook(ori_state):
        # Keep only the parts of the state that can lead to further DataPipes.
        if isinstance(ori_state, dict):
            return {key: val for key, val in ori_state.items() if isinstance(val, kept_types)}
        if isinstance(ori_state, (tuple, list)):
            return [val for val in ori_state if isinstance(val, kept_types)]
        if isinstance(ori_state, kept_types):
            return ori_state
        return None

    def reduce_hook(obj):
        # The scanned object itself and anything already cached are not captured.
        # NOTE(review): raising NotImplementedError presumably tells the DataPipe to fall
        # back to its regular reduction - confirm against the DataPipe `__reduce_ex__`.
        if obj == scan_obj or id(obj) in cache:
            raise NotImplementedError
        found.append(obj)
        # Adding id to remove duplicate DataPipe serialized at the same level
        cache.add(id(obj))
        return _stub_unpickler, ()

    hooked_classes: Tuple[Type[DataPipe]] = (IterDataPipe, MapDataPipe)  # type: ignore[assignment]

    try:
        for cls in hooked_classes:
            cls.set_reduce_ex_hook(reduce_hook)
            if only_datapipe:
                cls.set_getstate_hook(getstate_hook)
        try:
            std_pickler.dump(scan_obj)
        except (pickle.PickleError, AttributeError, TypeError):
            # Retry with dill when it is installed; otherwise surface the original error.
            if not DILL_AVAILABLE:
                raise
            dill_pickler.dump(scan_obj)
    finally:
        # Always uninstall the hooks, even when pickling failed.
        for cls in hooked_classes:
            cls.set_reduce_ex_hook(None)
            if only_datapipe:
                cls.set_getstate_hook(None)
        if DILL_AVAILABLE:
            from dill import extend as dill_extend

            dill_extend(False)  # Undo change to dispatch table
    return found
def traverse_dps(datapipe: DataPipe) -> DataPipeGraph:
    r"""
    Extract the DataPipe graph that ends at ``datapipe``.

    While walking the graph, only those attributes of each DataPipe are inspected
    that are either a DataPipe or a Python collection object such as ``list``,
    ``tuple``, ``set`` and ``dict``.

    Args:
        datapipe: the end DataPipe of the graph

    Returns:
        A graph represented as a nested dictionary, where keys are ids of DataPipe instances
        and values are tuples of DataPipe instance and the sub-graph
    """
    # Start every traversal with a fresh, empty set of visited ids.
    return _traverse_helper(datapipe, only_datapipe=True, cache=set())
def traverse(datapipe: DataPipe, only_datapipe: Optional[bool] = None) -> DataPipeGraph:
    r"""
    [Deprecated] Traverse the DataPipes and their attributes to extract the DataPipe graph.

    When ``only_datapipe`` is specified as ``True``, it only looks into the attributes
    of each DataPipe that are either a DataPipe or a Python collection object such as
    ``list``, ``tuple``, ``set`` and ``dict``.

    Note:
        This function is deprecated. Please use `traverse_dps` instead.

    Args:
        datapipe: the end DataPipe of the graph
        only_datapipe: If ``False`` or ``None`` (default), all attributes of each DataPipe
            are traversed. This argument is deprecated and will be removed after the next release.

    Returns:
        A graph represented as a nested dictionary, where keys are ids of DataPipe instances
        and values are tuples of DataPipe instance and the sub-graph

    Warns:
        FutureWarning: on every call, because the function itself is deprecated.
    """
    msg = "`traverse` function is deprecated and will be removed after 1.13. " \
          "Please use `traverse_dps` instead."
    if not only_datapipe:
        msg += " And, the behavior will be changed to the equivalent of `only_datapipe=True`."
    # stacklevel=2 attributes the warning to the code calling `traverse` instead of this module.
    warnings.warn(msg, FutureWarning, stacklevel=2)
    if only_datapipe is None:
        only_datapipe = False
    cache: Set[int] = set()
    return _traverse_helper(datapipe, only_datapipe, cache)
# Add cache here to prevent infinite recursion on DataPipe
def _traverse_helper(datapipe: DataPipe, only_datapipe: bool, cache: Set[int]) -> DataPipeGraph:
    r"""
    Recursively build the graph whose end node is ``datapipe``.

    Args:
        datapipe: the DataPipe to expand
        only_datapipe: forwarded unchanged to ``_list_connected_datapipes`` and to the recursive calls
        cache: ids of DataPipes already visited on the current path; the id of
            ``datapipe`` is added to it in place

    Returns:
        ``{id(datapipe): (datapipe, sub_graph)}``, or an empty dict when ``datapipe``
        is already in ``cache``.

    Raises:
        RuntimeError: if ``datapipe`` is neither an ``IterDataPipe`` nor a ``MapDataPipe``
    """
    if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
        raise RuntimeError("Expected `IterDataPipe` or `MapDataPipe`, but {} is found".format(type(datapipe)))

    node_id = id(datapipe)
    if node_id in cache:
        # Already seen on this path: stop here so that cycles do not recurse forever.
        return {}
    cache.add(node_id)

    # A copy of the cache is handed over so that the ids recorded while listing
    # connections do not leak into the cache used for other paths.
    children = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())

    sub_graph: DataPipeGraph = {}
    for child in children:
        # Each child gets its own copy: recursion is prevented per path, not globally,
        # because a single DataPipe may appear on several paths of the graph.
        sub_graph.update(_traverse_helper(child, only_datapipe, cache.copy()))
    return {node_id: (datapipe, sub_graph)}