Fixes #96975
Changes:
- Make sure a custom ShardingDataPipe with `apply_sharding` can be used by `DataLoader` (a minimal sketch follows this list)
- Allow the `apply_sharding` function to be defined without the last argument, `sharding_group`
- Make `DataLoader` not rely on `sharding_group`
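For illustration, a minimal sketch of such a custom sharding datapipe; the class is hypothetical, and only the two-argument `apply_sharding(num_of_instances, instance_id)` form (without `sharding_group`) reflects this change:
```python
from torch.utils.data import IterDataPipe

class CustomShardingIterDataPipe(IterDataPipe):  # hypothetical example datapipe
    def __init__(self, source):
        self.source = source
        self.num_of_instances = 1
        self.instance_id = 0

    # No trailing `sharding_group` parameter; DataLoader can still call this.
    def apply_sharding(self, num_of_instances, instance_id):
        self.num_of_instances = num_of_instances
        self.instance_id = instance_id

    def __iter__(self):
        for i, item in enumerate(self.source):
            if i % self.num_of_instances == self.instance_id:
                yield item
```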
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97287
Approved by: https://github.com/NivekT
Applies the remaining flake8-comprehension fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, more performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into just the `set` call.
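As an illustrative before/after (not taken from the diff):
```python
b = [1, 2, 2, 3]

# Before: unnecessary generator expressions inside list()/set() calls
squares = list(x * x for x in range(10))
unique = set(a for a in b)

# After: a comprehension, or just the constructor when the generator is useless
squares = [x * x for x in range(10)]
unique = set(b)
```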
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
Applies some more harmless pyupgrades. This one gets rid of deprecated aliases in unit tests and upgrades more `yield` for-loops into `yield from` generators, which are more performant and propagate more information/exceptions from the original generator. This is the modern recommended way of forwarding generators.
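For example (illustrative snippet, not from the diff):
```python
# Before: re-yielding each element in a loop
def forward_old(gen):
    for value in gen:
        yield value

# After: `yield from` delegates to the sub-generator, forwarding values,
# sent values, and exceptions from the original generator
def forward_new(gen):
    yield from gen
```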
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94309
Approved by: https://github.com/albanD
Move `ShardingFilterIterDataPipe` into a dedicated file.
Also, propose to have a dedicated parent class (`_ShardingIterDataPipe`) for sharding datapipes, as this seems more like a "system/engine-level" datapipe that gives strong hints to the RS on how to execute and needs first-class-citizen treatment in the RS (compared with other "user-level" datapipes that are mostly composable `Callable[[Iterable], Iterable]`). This way we don't need to rely on whether `is_shardable` and `apply_sharding` are present on a DataPipe in `graph_settings.py`. But open to other discussions.
Open question: Should [ShardingRoundRobinDispatcherIterDataPipe](01fc762003/torchdata/datapipes/iter/util/sharding.py (L16-L17)) also be considered a `_ShardingIterDataPipe`? (e.g. this sharding is executed by replicating the metadata, while `ShardingRoundRobinDispatcherIterDataPipe` hints that replication is too expensive and therefore requires round-robin data exchange/dispatch).
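A rough sketch of the proposed base class, just to make the shape concrete (illustrative only; the actual interface in this PR may differ):
```python
from torch.utils.data import IterDataPipe

class _ShardingIterDataPipe(IterDataPipe):
    """Marker base class for "system/engine-level" sharding datapipes.

    A ReadingService could check `isinstance(dp, _ShardingIterDataPipe)`
    instead of probing for `is_shardable`/`apply_sharding` attributes.
    """

    def apply_sharding(self, num_of_instances, instance_id):
        raise NotImplementedError
```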
Differential Revision: D43014692
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94095
Approved by: https://github.com/ejguan, https://github.com/NivekT
Changes:
- Allow multiple `sharding_filter` in the pipeline as long as they are not on the same branch
- [x] Add test
Example:
```mermaid
graph TD;
DP1-->sharding_filter_1;
sharding_filter_1-->DP3;
DP2-->sharding_filter_2;
sharding_filter_2-->DP4;
DP3-->DP4;
DP4-->output;
```
In order to properly shard `DP1` and `DP2`, we should allow multiple `sharding_filter`s
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90769
Approved by: https://github.com/NivekT
This PR deprecates the `traverse` function and replaces it with `traverse_datapipes`.
While using `DataLoader`, I realized that it raises a `FutureWarning` even though I am not explicitly using `traverse`. What is happening is that `DataLoader` invokes `traverse(dp, only_datapipe=True)`, and the usage of the keyword causes the `only_datapipe` deprecation warning to be raised.
```
/home/ubuntu/miniconda3/lib/python3.8/site-packages/torch/utils/data/graph.py:102: FutureWarning: `only_datapipe` is deprecated from `traverse` function and will be removed after 1.13.
warnings.warn(msg, FutureWarning)
```
A few things we'd like to do:
1. Deprecate the keyword arg `only_datapipe`
2. Change the default behavior from `only_datapipe=False` to `only_datapipe=True` in the future
3. Do not raise a warning when users are using the function correctly
This creates a paradox: it is impossible for users to change their code to match the future default behavior (i.e. call `traverse(dp)` without `only_datapipe`):
- they cannot do so because the default behavior of `traverse` hasn't changed yet, so they must use `only_datapipe=True`
- if they use `only_datapipe=True`, eventually the kwarg will go away and cause a runtime error; they also get a `FutureWarning` in the present
IIUC, there doesn't seem to be a way to accomplish those 3 goals without replacing the function with a new one that has a different name; hence, this PR. Let me know if there is a better alternative.
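To make the migration concrete, a hedged sketch of the before/after, assuming the replacement is exposed from `torch.utils.data.graph` under the name used in this PR:
```python
from torch.utils.data.datapipes.iter import IterableWrapper
from torch.utils.data.graph import traverse_datapipes  # name introduced by this PR

dp = IterableWrapper(range(10)).shuffle()

# Old: warns even when used "correctly"
# graph = traverse(dp, only_datapipe=True)

# New: same only-DataPipe graph, no keyword and no FutureWarning
graph = traverse_datapipes(dp)
```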
If this looks right, I will send a follow up PR in `TorchData`.
Differential Revision: [D39832183](https://our.internmc.facebook.com/intern/diff/D39832183)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85667
Approved by: https://github.com/ejguan
This PR requires the following PR to be landed first: https://github.com/pytorch/pytorch/pull/83202
## changes
- For `apply_shuffle_setting` and `apply_shuffle_seed`, make sure the shuffle setting is applied to every DataPipe that has a `set_shuffle` or `set_seed` method.
- Change the API from `apply_shuffle_seed` to `apply_random_seed`.
- Fix a bug where `apply_shuffle_seed` only accepted hashable DataPipes. After this PR, the function uses `id` to prevent seeding the same DataPipe multiple times per epoch (a sketch follows this list).
- Fix another bug in `shuffler` where `reset` with `_enable=False` would also reset `_seed`.
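A minimal sketch of the `id`-based de-duplication described above (simplified; the real function walks the DataPipe graph):
```python
import torch

def apply_random_seed_sketch(datapipes, rng: torch.Generator):
    seeded_ids = set()
    for dp in datapipes:
        # `id` avoids requiring DataPipes to be hashable and prevents
        # seeding the same instance twice within an epoch.
        if id(dp) in seeded_ids or not hasattr(dp, "set_seed"):
            continue
        seed = int(torch.empty((), dtype=torch.int64).random_(generator=rng).item())
        dp.set_seed(seed)
        seeded_ids.add(id(dp))
```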
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83741
Approved by: https://github.com/NivekT
As @pmeier [points out](https://github.com/pytorch/pytorch/pull/80267#discussion_r958423241), #80267 introduces a bug where an exception is thrown when a built-in function (or a function implemented in C) is used with `.map` because `inspect.signature(fn)` cannot find the function's signature.
This PR skips over a function when its signature cannot be found. I believe this case is rare, and if the `fn` is truly incompatible with the usage of `input_col`/`output_col`, an exception will be raised at run time such that users will be able to examine what is wrong.
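A hedged sketch of the skip-over behavior (the helper name here is hypothetical; the real validation lives inside the map datapipe utilities):
```python
import inspect

def _validate_map_fn(fn, input_col, output_col):  # hypothetical helper for illustration
    try:
        sig = inspect.signature(fn)
    except ValueError:
        # Built-in / C-implemented callables may not expose a signature;
        # skip the static check and let any real mismatch surface at run time.
        return
    # ... continue validating `sig` against input_col/output_col here ...
```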
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84279
Approved by: https://github.com/pmeier, https://github.com/janeyx99
Fixes: https://github.com/pytorch/data/issues/718
This is an alternative PR against https://github.com/pytorch/pytorch/pull/82974
This PR changes the behavior of both types to match `IterDataPipe.shuffle`:
- Lazily generate the seed per iteration
- Each iterator gets a new seed
- Convert `MapDataPipe.shuffle` to an `IterDataPipe`
## BC-breaking Note:
This PR changes the return type of `MapDataPipe.shuffle` from a `MapDataPipe` to a `IterDataPipe`.
### 1.12:
Output as `MapDataPipe`
```
>>> from torch.utils.data import IterDataPipe, MapDataPipe
>>> from torch.utils.data.datapipes.map import SequenceWrapper
>>> dp = SequenceWrapper(list(range(10))).shuffle()
>>> isinstance(dp, MapDataPipe)
True
>>> isinstance(dp, IterDataPipe)
False
```
### This PR:
Output as `IterDataPipe`
```
>>> from torch.utils.data import IterDataPipe, MapDataPipe
>>> from torch.utils.data.datapipes.map import SequenceWrapper
>>> dp = SequenceWrapper(list(range(10))).shuffle()
>>> isinstance(dp, MapDataPipe)
False
>>> isinstance(dp, IterDataPipe)
True
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83202
Approved by: https://github.com/NivekT
This PR changes the behavior of `IterDataPipe` to always invoke `reset` for the `NotStarted` state. The main reason is that we normally put lazy initialization code into the `reset` function. Even for the `NotStarted` state, we should invoke `reset` to initialize those lazy variables; otherwise, we would have to manually determine in `__iter__` whether the state is `NotStarted` or `Iterating` and invoke `reset` only for the `NotStarted` state.
This PR also makes `Shuffler` able to serialize its `buffer` and `rng_state`.
The following part is removed:
~I am also add `_snapshot_state` into serialization state and during `__setstate__` only change the state to `Restored` if the original state is `Iterating`. Especially, for the case of deserializing/serializing `NotStarted` DataPipe (multiprocessing), we would invoke `set_seed` for `Shuffler`. We need the `DataPipe` remains as `NotStarted` to properly `reset`.~
I am listing all the expected state transitions below (a simplified sketch of the corresponding `__iter__` handling follows the list):
- Initial state: `NotStarted`
  - `iter` -> Call `reset` and change the state to `Iterating`
  - serialize/deserialize -> Keep the state as `NotStarted` (will `reset` if `iter` is called afterwards)
- Initial state: `Iterating`
  - `iter` -> Call `reset` and keep the state as `Iterating`
  - serialize/deserialize -> Change the state to `Restored`
- Initial state: `Restored`
  - `iter` -> Only change the state to `Iterating`
  - serialize/deserialize -> Not allowed
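A simplified sketch of the `__iter__`-time handling implied by these transitions (illustrative only; the real logic is hooked into `IterDataPipe`):
```python
from enum import Enum, auto

class _SnapshotState(Enum):  # mirrors the states named above
    NotStarted = auto()
    Restored = auto()
    Iterating = auto()

def _on_iter(datapipe):
    if datapipe._snapshot_state == _SnapshotState.Restored:
        # Restored: resume without wiping the restored buffer/rng_state
        datapipe._snapshot_state = _SnapshotState.Iterating
    else:
        # NotStarted and Iterating both reset, so lazy initialization in
        # `reset` always runs before the first element is produced
        datapipe.reset()
        datapipe._snapshot_state = _SnapshotState.Iterating
```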
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83535
Approved by: https://github.com/NivekT
Fixes https://github.com/pytorch/data/issues/708
The following code snippet used to fail, now it has been added as a test case:
```python
import torch.utils.data.datapipes as dp

dp1 = dp.map.SequenceWrapper(range(10))
shuffle_dp1 = dp1.shuffle()
dp2 = dp.map.SequenceWrapper(range(10))
shuffle_dp2 = dp2.shuffle()
zip_dp = shuffle_dp1.zip(shuffle_dp2)
list(zip_dp) # This used to fail
```
The issue was that `ShufflerMapDataPipe` raises a `KeyError` when an out-of-bounds index is passed into it, but that was not handled by `zip_dp`'s `__getitem__`, which only handled `IndexError`. With this change, it handles both.
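A hedged sketch of the fixed `__getitem__` (names simplified, not the verbatim source):
```python
from torch.utils.data import MapDataPipe

class _ZipperSketch(MapDataPipe):  # stand-in for the real Zipper MapDataPipe
    def __init__(self, *datapipes):
        self.datapipes = datapipes

    def __getitem__(self, index):
        res = []
        for dp in self.datapipes:
            try:
                res.append(dp[index])
            except (IndexError, KeyError):
                # ShufflerMapDataPipe raises KeyError for an unknown index;
                # normalize it to IndexError so iteration terminates cleanly.
                raise IndexError(f"Index {index} is out of range for {dp}.")
        return tuple(res)

    def __len__(self):
        return min(len(dp) for dp in self.datapipes)
```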
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82666
Approved by: https://github.com/ejguan
This mostly completes the "poor man's snapshotting" implementation (named "simple snapshotting"). This is the most basic version of snapshotting, but it should work for all DataPipes. I will be adding more efficient implementations for different types of DataPipes in future PRs.
### Implementation
The general idea of the simple snapshot is that we will:
1. Create a new iterator
2. Move that iterator forward by `n_iterations`
3. Save that as the `_fast_forward_iterator` of the DataPipe
4. The next time `iter` is called on the DataPipe, use the `_fast_forward_iterator`
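Those four steps roughly correspond to the following sketch (simplified; the in-tree helper also restores RNG state across the graph):
```python
def _fast_forward_sketch(datapipe, n_iterations):
    # 1-2. Create a fresh iterator and advance it n_iterations times
    it = iter(datapipe)
    for _ in range(n_iterations):
        next(it)
    # 3. Save it on the DataPipe so that...
    # 4. ...the next call to iter(datapipe) hands back this iterator
    datapipe._fast_forward_iterator = it
```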
### Usage
As of this implementation, the usage will look something like this:
```python
rng = torch.Generator()
initial_rng_state = rng.get_state()
datapipe: IterDataPipe = ...
# Some usage of the DataPipe, here maybe yielding the first 5 values
n_iter = 5
it = iter(datapipe)
for _ in range(n_iter):
    next(it)
serialized_graph = pickle.dumps(datapipe)
# The serialized object has most of the sufficient information for simple snapshot (except for initial RNG state)
# It can be deserialized at a later point in time or by a different process
deserialized_graph = pickle.loads(serialized_graph)
# I think `DataLoader2` or `ReadingService` should store `initial_rng_state` that can be saved by the API that we later use
rng_for_deserialized = torch.Generator()
rng_for_deserialized.set_state(initial_rng_state)
n_iterations = deserialized_graph._number_of_samples_yielded
_simple_snapshot_graph(deserialized_graph, n_iterations, rng=rng_for_deserialized)
# The whole DataPipe graph should have the same state as before serialization, such that:
self.assertEqual(list(it), list(deserialized_graph)) # True
```
### Next Steps
If this looks acceptable, the next step is to modify `DataLoader2`'s prototype ReadingService (the one with queues) to remember things like `initial_rng_state` and to add a `save_snapshot` method that returns the `(serialized graph, initial_rng)` pair, plus a `restore_snapshot` method. This should work for single-worker data loading.
Note that, in the long term, `initial_rng_state` may not be necessary if we are able to directly save/restore the buffer and RNG state of `Shuffler` (that is work in progress). However, `initial_rng_state` and the simple snapshot are still a good fall-back option for edge cases where the buffer can't be stored.
Differential Revision: [D37943406](https://our.internmc.facebook.com/intern/diff/D37943406)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79479
Approved by: https://github.com/ejguan
I went through most of the warnings and exceptions raised in our tests to find these issues.
Changes:
1. In testing, `self.assertEquals` is deprecated; convert to `self.assertEqual` to get rid of the warning
2. Small changes for cleanliness and to get rid of warnings (no actual change to results)
3. Correct the `is_every_instance_exhausted` logic for `_Forker`
4. Catch `RuntimeError` raised by an invalidated iterator during cleanup
5. Check if attribute `parent_stream` exists before trying to access it
Differential Revision: [D38020122](https://our.internmc.facebook.com/intern/diff/D38020122)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81833
Approved by: https://github.com/ejguan
Summary:
This Diff removes the requirement that a `DataPipe` be hashable for the `traverse` function. `traverse` now uses the `id` of the `DataPipe` instance rather than the `DataPipe` itself as the key for both the `cache` and the graph.
This requires changing the type of `DataPipeGraph` from `Dict[DataPipe, "DataPipeGraph"]` to `Dict[int, Tuple[DataPipe, "DataPipeGraph"]]`.
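Spelled out as type aliases (illustrative; the names mirror those in the description):
```python
from typing import Dict, Tuple, Union
from torch.utils.data import IterDataPipe, MapDataPipe

DataPipe = Union[IterDataPipe, MapDataPipe]

# Old: keyed by the DataPipe itself, which required hashability
# DataPipeGraph = Dict[DataPipe, "DataPipeGraph"]

# New: keyed by id(), keeping the instance alongside its sub-graph
DataPipeGraph = Dict[int, Tuple[DataPipe, "DataPipeGraph"]]
```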
Differential Revision: D37354153
Ref PR in TorchData: https://github.com/pytorch/data/pull/559
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80509
Approved by: https://github.com/VitalyFedyunin
This PR adds an attribute and logic to count the number of successful yields from an `IterDataPipe`. This information can be useful for fast-forwarding a DataPipe (or the entire graph) back to a certain state.
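A minimal sketch of the counting (illustrative; in the PR the counter lives on `IterDataPipe` itself and is updated by the iterator hook):
```python
from torch.utils.data import IterDataPipe

class _CountingSketch(IterDataPipe):  # illustrative only
    def __init__(self, source):
        self.source = source
        self._number_of_samples_yielded = 0

    def __iter__(self):
        self._number_of_samples_yielded = 0
        for item in self.source:
            # Increment once the source has successfully produced an item,
            # so the count can later serve as a fast-forward target.
            self._number_of_samples_yielded += 1
            yield item
```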
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79657
Approved by: https://github.com/VitalyFedyunin
Summary:
X-link: https://github.com/pytorch/data/pull/547
Fixes https://github.com/pytorch/data/issues/538
- Improve the validation function to raise a warning about an unpicklable function when either a lambda or a local function is provided to a `DataPipe`.
- The inner function of a `functools.partial` object is extracted for validation as well.
- Mimic the behavior of the `pickle` module for a lambda defined inside a local function: it only raises an error about the local function rather than the lambda, so we likewise warn about the local function, not the lambda.
```py
>>> import pickle
>>> def fn():
...     lf = lambda x: x
...     pickle.dumps(lf)
>>> fn()
AttributeError: Can't pickle local object 'fn.<locals>.<lambda>'
```
This Diff also fixes the Error introduced by https://github.com/pytorch/pytorch/pull/79344
Test Plan:
CI on PyTorch and TorchData
Manually validated the tests from TorchVision
Differential Revision: D37417556
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80232
Approved by: https://github.com/NivekT
Fixes https://github.com/pytorch/data/issues/538
- Improve the validation function to raise a warning about an unpicklable function when either a lambda or a local function is provided to a `DataPipe`.
- The inner function of a `functools.partial` object is extracted for validation as well.
- Mimic the behavior of the `pickle` module for a lambda defined inside a local function: it only raises an error about the local function rather than the lambda, so we likewise warn about the local function, not the lambda.
```py
>>> import pickle
>>> def fn():
...     lf = lambda x: x
...     pickle.dumps(lf)
>>> fn()
AttributeError: Can't pickle local object 'fn.<locals>.<lambda>'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80140
Approved by: https://github.com/VitalyFedyunin, https://github.com/NivekT
Fixes https://github.com/pytorch/data/issues/426
This PR introduces two main changes:
- It ensures the `ShufflerDataPipe` shares the same seed across distributed processes.
- Users can reset `shuffle` for persistent workers per epoch.
Detail:
- `shared_seed` is shared across distributed and worker processes. It seeds a `shared_rng` that provides seeds to each `ShufflerDataPipe` in the pipeline (a sketch follows this list).
- `worker_loop` now accepts a new `shared_seed` argument to receive this shared seed.
- The `shared_seed` is attached to `_ResumeIteration` so the seed is reset per epoch for persistent workers.
- I chose not to touch `base_seed`, simply to avoid BC issues.
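A hedged sketch of the seeding scheme (simplified; the real code walks the DataPipe graph inside the worker/distributed setup):
```python
import torch

def seed_shufflers_sketch(shufflers, shared_seed: int):
    # Every distributed/worker process builds the same shared_rng from the
    # broadcast shared_seed, so all ShufflerDataPipes draw identical seeds.
    shared_rng = torch.Generator()
    shared_rng.manual_seed(shared_seed)
    for shuffler in shufflers:
        seed = int(torch.empty((), dtype=torch.int64).random_(generator=shared_rng).item())
        shuffler.set_seed(seed)
```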
I used this [script](https://gist.github.com/ejguan/d88f75fa822cb696ab1bc5bc25844f47) to test the result with `world_size=4`. Please check the result in: https://gist.github.com/ejguan/6ee2d2de12ca57f9eb4b97ef5a0e300b
You can see there aren't any duplicated or missing elements in any epoch. And, with the same seed, the order of data remains the same across epochs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78765
Approved by: https://github.com/VitalyFedyunin
This is the first PR to make DataPipe deterministic.
Users should be able to use `torch.manual_seed(seed)` to control the shuffle order in the following cases:
- Directly over a `DataPipe`
- Single-process DataLoader
- Multiprocessing DataLoader
Unfortunately, for distributed training, users have to run `apply_shuffle_seed` manually to make sure all distributed processes have the same shuffle order.
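A minimal illustration of the single-process case (the distributed case additionally requires the manual `apply_shuffle_seed` call mentioned above):
```python
import torch
from torch.utils.data.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(10)).shuffle()

torch.manual_seed(0)
first_epoch = list(dp)
torch.manual_seed(0)
second_epoch = list(dp)
assert first_epoch == second_epoch  # same seed, same shuffle order
```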
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77741
Approved by: https://github.com/VitalyFedyunin, https://github.com/NivekT
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76384
OSS issue discussion: https://github.com/pytorch/data/issues/346
This diff updates `mux` and `mux_longest` data pipe.
`mux`: Yields one element at a time from each of the input Iterable DataPipes (functional name: `mux`). As in, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration, and so on. It ends when the shortest input DataPipe is exhausted.
`mux` example:
```
>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
>>> list(dp1.mux(dp2, dp3))
[0, 10, 20, 1, 11, 21, 2, 12, 22]
```
Test Plan:
buck test mode/opt //caffe2/test:datapipe
https://www.internalfb.com/intern/testinfra/testrun/4785074706282345
Differential Revision: D36017945
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77145
Approved by: https://github.com/NivekT, https://github.com/ejguan
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76384
OSS issue discussion: https://github.com/pytorch/data/issues/346
This diff updates `mux` and `mux_longest` data pipe.
`mux`: Yields one element at a time from each of the input Iterable DataPipes (functional name: `mux`). As in, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration, and so on. It ends when the shortest input DataPipe is exhausted.
`mux` example:
```
>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
>>> list(dp1.mux(dp2, dp3))
[0, 10, 20, 1, 11, 21, 2, 12, 22]
```
Test Plan:
buck test mode/dev //pytorch/data/test:tests -- --exact 'pytorch/data/test:tests - test_mux_longest_iterdatapipe (test_datapipe.TestDataPipe)'
https://www.internalfb.com/intern/testinfra/testrun/3096224791148107
Reviewed By: ejguan
Differential Revision: D35799965
fbshipit-source-id: 320e71a342ec27e6e9200624aad42f4b99f97c3a
(cherry picked from commit 741ed595275df6c05026ed6f0e78d7052328fb7d)