Commit Graph

74 Commits

Author SHA1 Message Date
Yuanyuan Chen
f591bb5056 Remove data_source argument from Sampler (#163134)
`data_source` was declared for removal in PT 2.2 but was never actually removed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163134
Approved by: https://github.com/ezyang
2025-09-21 05:44:41 +00:00
dsashidh
a87aea03f7 Update RandomSampler docstring. data_source must be Sized not Dataset (#158857)
Fixes #158631

The docstring said `data_source` was a `Dataset`, but `RandomSampler` only needs something that implements `__len__`. This updates the docstring to use `Sized` instead, which matches the actual type used in the constructor.
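To illustrate, `RandomSampler` is satisfied by any object exposing `__len__`; no `Dataset` subclassing is needed (the `Corpus` class below is a made-up stand-in):

```python
import torch
from torch.utils.data import RandomSampler

class Corpus:
    """Hypothetical container: Sized, but not a torch Dataset."""
    def __len__(self) -> int:
        return 4

sampler = RandomSampler(Corpus(), generator=torch.Generator().manual_seed(0))
indices = list(sampler)
print(sorted(indices))  # [0, 1, 2, 3] -- some permutation of the four indices
```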

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158857
Approved by: https://github.com/divyanshk
2025-09-20 04:05:25 +00:00
Xuehai Pan
5cedc5a0ff [BE][PYFMT] migrate PYFMT for torch/[p-z]*/ to ruff format (#144552)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144552
Approved by: https://github.com/ezyang
2025-08-07 00:09:56 +00:00
Divyansh Khanna
e6d8ed02cb PyTorch Data Sampler benchmark (#156974)
## Motivation
Many PRs optimizing samplers (e.g. https://github.com/pytorch/pytorch/pull/147706, https://github.com/pytorch/pytorch/pull/137423) rely on an ad hoc script for benchmarking samplers. The script and outputs are often copied over in PRs. We want to begin centralizing benchmarks for torch.utils.data components.

## What ?
* This PR adds a new sub-folder in `benchmarks`  for `data`. This is aimed to cover benchmarking scripts for torch.utils.data components like dataloader and sampler.
* Specifically, this PR includes a simple script to time samplers. This is often "copy-pasted" in PRs optimizing samplers. Having it in a centralized location should prevent that, and allow a common standard.

## Output
```
Benchmark Results:
+--------------+-------------+----------------+-----------+-----------+
|   Batch Size | Drop Last   |   Original (s) |   New (s) | Speedup   |
+==============+=============+================+===========+===========+
|            4 | True        |         0.004  |    0.0088 | -119.62%  |
+--------------+-------------+----------------+-----------+-----------+
|            4 | False       |         0.0083 |    0.009  | -9.23%    |
+--------------+-------------+----------------+-----------+-----------+
|            8 | True        |         0.003  |    0.0074 | -147.64%  |
+--------------+-------------+----------------+-----------+-----------+
|            8 | False       |         0.0054 |    0.0075 | -38.72%   |
+--------------+-------------+----------------+-----------+-----------+
|           64 | True        |         0.0021 |    0.0056 | -161.92%  |
+--------------+-------------+----------------+-----------+-----------+
|           64 | False       |         0.0029 |    0.0055 | -92.50%   |
+--------------+-------------+----------------+-----------+-----------+
|          640 | True        |         0.002  |    0.0055 | -168.75%  |
+--------------+-------------+----------------+-----------+-----------+
|          640 | False       |         0.0024 |    0.0062 | -161.35%  |
+--------------+-------------+----------------+-----------+-----------+
|         6400 | True        |         0.0021 |    0.0055 | -160.13%  |
+--------------+-------------+----------------+-----------+-----------+
|         6400 | False       |         0.0021 |    0.0068 | -215.46%  |
+--------------+-------------+----------------+-----------+-----------+
|        64000 | True        |         0.0042 |    0.0065 | -55.29%   |
+--------------+-------------+----------------+-----------+-----------+
|        64000 | False       |         0.0029 |    0.0077 | -169.56%  |
+--------------+-------------+----------------+-----------+-----------+
```
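A minimal harness in this spirit (a sketch, not the exact benchmark script added by the PR) can be as simple as timing full passes over a `BatchSampler`:

```python
import timeit
from torch.utils.data import BatchSampler, SequentialSampler

def time_batch_sampler(n: int, batch_size: int, drop_last: bool,
                       repeats: int = 10) -> float:
    """Time `repeats` full iterations over a BatchSampler of n indices."""
    sampler = BatchSampler(
        SequentialSampler(range(n)), batch_size=batch_size, drop_last=drop_last
    )
    return timeit.timeit(lambda: list(sampler), number=repeats)

for bs in (4, 64, 640):
    print(f"batch_size={bs:>4}  {time_batch_sampler(100_000, bs, True):.4f}s")
```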
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156974
Approved by: https://github.com/ramanishsingh
2025-06-27 04:49:43 +00:00
Luca Arnaboldi
c3bb174bb2 SubsetRandomSampler - changed iteration over tensor to iteration over list (#149126)
Digging further the problem at https://github.com/UKPLab/sentence-transformers/pull/3261, it boils down to this expensive loop over a torch tensor. Looping over a list, like in RandomSampler, solves the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149126
Approved by: https://github.com/divyanshk, https://github.com/cyyever
2025-03-31 04:33:35 +00:00
Aaron Orenstein
2f9d378f7b PEP585 update - torch/utils (#145201)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145201
Approved by: https://github.com/bobrenjc93
2025-01-21 21:04:10 +00:00
xangma
fe8d66d9a6 Faster Faster BatchSampler (#137423)
Builds upon #76951.

Benchmarking code is the same as in #76950.

AMD Ryzen Threadripper PRO 3995WX:
```
  batch_size  drop_last      origin     new  speedup
------------  -----------  --------  ------  ---------
           4  True           0.94    0.5706  64.74%
           4  False          0.9745  0.9468  2.93%
           8  True           0.7423  0.3715  99.82%
           8  False          0.7974  0.5666  40.73%
          64  True           0.5394  0.2085  158.76%
          64  False          0.6083  0.2697  125.51%
         640  True           0.5448  0.1985  174.41%
         640  False          0.7085  0.2308  206.91%
        6400  True           0.5554  0.2028  173.88%
        6400  False          0.7711  0.2109  265.60%
       64000  True           0.556   0.2091  165.82%
       64000  False          0.7803  0.2078  275.58%
```

When `drop_last == True`, it uses `zip` to speed things up.
When `drop_last == False`, it uses `itertools` to speed things up.

`itertools` was the fastest way I could find that deals with the last batch if it is smaller than `batch_size`. I have a pure python method too, but it is slower when `batch_size` is 4 or 8, so I have committed the `itertools` version for now.
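The two tricks can be sketched in isolation (simplified pure-Python versions of the idea, not the exact code merged into `sampler.py`):

```python
from itertools import islice

def batches_drop_last(indices, batch_size):
    # zip over batch_size references to one shared iterator yields only full
    # batches, silently dropping a trailing partial one -- the drop_last=True path.
    it = iter(indices)
    for batch in zip(*[it] * batch_size):
        yield list(batch)

def batches_keep_last(indices, batch_size):
    # islice pulls up to batch_size items per step and keeps the short
    # final batch -- the drop_last=False path.
    it = iter(indices)
    while batch := list(islice(it, batch_size)):
        yield batch

print(list(batches_drop_last(range(7), 3)))  # [[0, 1, 2], [3, 4, 5]]
print(list(batches_keep_last(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```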

Happy to chat further about this change :-) I understand you may not want to introduce the `itertools` package into [sampler.py](https://github.com/pytorch/pytorch/blob/main/torch/utils/data/sampler.py).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137423
Approved by: https://github.com/Skylion007
2024-10-13 09:36:03 +00:00
Xuehai Pan
f1df13f023 [BE][Easy] Fix PYI001: unprefixed-type-param in torch/utils/data/datapipes (#129885)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129885
Approved by: https://github.com/ezyang
2024-07-02 14:56:27 +00:00
Xuehai Pan
7cf0b90e49 [BE] enable UFMT in torch.utils.data (#127705)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127705
Approved by: https://github.com/ezyang
ghstack dependencies: #127706, #127704
2024-06-27 23:16:24 +00:00
Xuehai Pan
f911957573 [BE] sort imports in torch.utils.data (#127704)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127704
Approved by: https://github.com/ezyang
ghstack dependencies: #127706
2024-06-27 23:16:24 +00:00
Xuehai Pan
dcc0093dba [BE][Easy] export explicitly imported public submodules (#127703)
Add top-level submodules `torch.{storage,serialization,functional,amp,overrides,types}`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127703
Approved by: https://github.com/ezyang
2024-06-12 05:52:18 +00:00
Aaron Orenstein
8db9dfa2d7 Flip default value for mypy disallow_untyped_defs [9/11] (#127846)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127846
Approved by: https://github.com/ezyang
ghstack dependencies: #127842, #127843, #127844, #127845
2024-06-08 18:50:06 +00:00
Alexander Kurakin
6f1935b0b5 doc: torch.utils.data.Sampler: __len__ is optional (#125938)
Fixes #ISSUE_NUMBER
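Concretely, a sampler may omit `__len__` entirely, for example an infinite stream of indices (the class below is illustrative, not from the PR):

```python
import itertools
from torch.utils.data import Sampler

class InfiniteSampler(Sampler[int]):
    """Hypothetical sampler with no __len__: cycles over [0, n) forever."""
    def __init__(self, n: int) -> None:
        self.n = n
    def __iter__(self):
        return itertools.cycle(range(self.n))

first_seven = list(itertools.islice(iter(InfiniteSampler(3)), 7))
print(first_seven)  # [0, 1, 2, 0, 1, 2, 0]
```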

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125938
Approved by: https://github.com/andrewkho, https://github.com/xmfan
2024-05-20 22:20:36 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Aryan Gupta
92e7f79609 Doc: Add and Fix docstrings for torch.util.data files (#112817)
Fixes #112635

Fix docstrings for `torch.utils.data` files.

```
> pydocstyle torch/utils/data/graph.py --count
Before: 5
After: 1

> pydocstyle torch/utils/data/graph_settings.py --count
Before: 8
After: 3

> pydocstyle torch/utils/data/dataloader.py --count
Before: 12
After: 6

> pydocstyle torch/utils/data/dataset.py --count
Before: 28
After: 23

> pydocstyle torch/utils/data/sampler.py --count
Before: 24
After: 19

> pydocstyle torch/utils/data/_utils/signal_handling.py --count
Before: 1
After: 0

> pydocstyle torch/utils/data/_utils/__init__.py --count
Before: 2
After: 0

> pydocstyle torch/utils/data/_utils/collate.py --count
Before: 20
After: 6

> pydocstyle torch/utils/data/_utils/fetch.py --count
Before: 3
After: 0

> pydocstyle torch/utils/data/_utils/pin_memory.py --count
Before: 4
After: 1

> pydocstyle torch/utils/data/datapipes/_decorator.py --count
Before: 19
After: 16

> pydocstyle torch/utils/data/datapipes/_hook_iterator.py --count
Before: 13
After: 0

> pydocstyle torch/utils/data/datapipes/_typing.py --count
Before: 17
After: 4

> pydocstyle torch/utils/data/datapipes/gen_pyi.py --count
Before: 19
After: 4
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112817
Approved by: https://github.com/kit1980
2023-11-07 17:59:56 +00:00
PyTorch MergeBot
3a284dae30 Revert "Do not materialize entire randperm in RandomSampler (#103339)"
This reverts commit d80174e2db.

Reverted https://github.com/pytorch/pytorch/pull/103339 on behalf of https://github.com/kit1980 due to Cause issues on MPS, and also fails without numpy ([comment](https://github.com/pytorch/pytorch/pull/103339#issuecomment-1781705172))
2023-10-26 18:53:14 +00:00
katotaisei
bcda859e34 fix typos (#108006)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108006
Approved by: https://github.com/Skylion007
2023-08-28 19:49:09 +00:00
Aaron Gokaslan
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling it so that it stays that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please help them get unblocked? It seems like one of the strings was prob accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling it so that it stays that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
Justin Chu
4cc1745b13 [BE] f-stringify torch/ and scripts (#105538)
This PR is a follow up on the pyupgrade series to convert more strings to use f-strings using `flynt`.

- https://docs.python.org/3/reference/lexical_analysis.html#f-strings
- https://pypi.org/project/flynt/

Command used:

```
flynt torch/ -ll 120
flynt scripts/ -ll 120
flynt tools/ -ll 120
```

and excluded `collect_env.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105538
Approved by: https://github.com/ezyang, https://github.com/malfet
2023-07-21 19:35:24 +00:00
Avi Verma
d80174e2db Do not materialize entire randperm in RandomSampler (#103339)
In our DDP training workloads, each rank was initializing a `RandomSampler` for a dataset with a length of 3.5 billion items. We noticed that when this sampler was in scope, `gc.collect` calls were taking on the order of seconds to run, which would slow down the entire training iteration. This is because when we call `torch.randperm(n).tolist()`, we create a python list of 3.5 billion items, which massively slows down the periodic mark & sweep garbage collection.

This PR swaps out the `.tolist()` call with a `.numpy()` call and manually calls `.item()` on each element as it is being requested. This has two benefits:

1. The first call to `RandomSampler::__next__` should be about twice as fast, since `.numpy` does not copy the contents of the original tensor
2. The runtime of `gc.collect()` calls no longer scales linearly with the size of the dataset passed to `RandomSampler`

I've attached some `timeit` samples to illustrate the speedups with this PR:

```
Main (no GC):  51.72115747816861
Main (10 GC calls) 83.61965207383037
PR (no GC) 33.06403830461204
PR (10 GC calls) 33.959467427805066
```

Code
```python
from timeit import timeit

baseline_no_gc = """
import torch

n = int(1e9)
steps = n // 100

x = torch.randperm(n).tolist()
x_iter = iter(x)

for i in range(steps):
    next(x_iter)
"""

baseline_gc = """
import torch
import gc
n = int(1e9)
steps = n // 100
gc_every = steps // 10

x = torch.randperm(n).tolist()
x_iter = iter(x)

for i in range(steps):
    next(x_iter)
    if i % gc_every == 0:
        gc.collect()
"""

numpy_no_gc = """
import torch
n = int(1e9)
steps = n // 100

x = torch.randperm(n).numpy()
x_iter = (i.item() for i in x)

for i in range(steps):
    next(x_iter)
"""

numpy_gc = """
import torch
import gc
n = int(1e9)
steps = n // 100
gc_every = steps // 10

x = torch.randperm(n).numpy()
x_iter = (i.item() for i in x)

for i in range(steps):
    next(x_iter)
    if i % gc_every == 0:
        gc.collect()
"""

if __name__ == "__main__":
    print("Main (no GC): ", timeit(baseline_no_gc, number=1))
    print("Main (10 GC calls)", timeit(baseline_gc, number=1))
    print("PR (no GC)",  timeit(numpy_no_gc, number=1))
    print("PR (10 GC calls)", timeit(numpy_gc, number=1))

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103339
Approved by: https://github.com/kit1980
2023-06-16 19:25:58 +00:00
Ramil Nugmanov
3e18d3958b [DataLoader] Follow-up Fix: TypeVars of Sampler (#100409)
API backward compatibility fixed:
https://github.com/pytorch/pytorch/pull/97338#discussion_r1169164163

Map-style Datasets can now accept non-integer indices from custom Samplers.

Fixes #97338
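The restored flexibility can be sketched like this: a map-style dataset keyed by strings, driven by a sampler that yields those keys (both classes are hypothetical illustrations):

```python
from torch.utils.data import Dataset, Sampler

class TableDataset(Dataset):
    """Hypothetical map-style dataset keyed by strings rather than ints."""
    def __init__(self, rows: dict) -> None:
        self.rows = rows
    def __getitem__(self, key: str):
        return self.rows[key]
    def __len__(self) -> int:
        return len(self.rows)

class KeySampler(Sampler[str]):
    """Yields the dataset's own (non-integer) keys."""
    def __init__(self, rows: dict) -> None:
        self.keys = list(rows)
    def __iter__(self):
        return iter(self.keys)
    def __len__(self) -> int:
        return len(self.keys)

rows = {"a": 1, "b": 2}
ds = TableDataset(rows)
print([ds[k] for k in KeySampler(rows)])  # [1, 2]
```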

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100409
Approved by: https://github.com/ejguan, https://github.com/NivekT
2023-05-03 17:38:31 +00:00
Ramil Nugmanov
867b07b424 Sampler API described for customization. (#97338)
Explanation with examples of sampler customization added.

* fixed TypeVar
* removed unused init from Sampler class
* added examples for custom sampler and batch sampler
* Distributed sampler typing fixed.
* _InfiniteConstantSampler fixed

Fixes #92268
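In the spirit of those examples, a custom batch sampler just yields lists of indices. The one below (hypothetical, not taken from the added docs) grows the batch size each step:

```python
from torch.utils.data import Sampler

class ProgressiveBatchSampler(Sampler[list]):
    """Hypothetical batch sampler: batch size doubles each step (1, 2, 4, ...)."""
    def __init__(self, n: int) -> None:
        self.n = n
    def __iter__(self):
        start, size = 0, 1
        while start < self.n:
            yield list(range(start, min(start + size, self.n)))
            start += size
            size *= 2

print(list(ProgressiveBatchSampler(6)))  # [[0], [1, 2], [3, 4, 5]]
```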

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97338
Approved by: https://github.com/NivekT
2023-03-28 06:40:38 +00:00
joncrall
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
Oliver Sellwood
cc6a51c9f3 added shape checking to WeightedRandomSampler (#78585)
Fixes #78236

An erroneously shaped weights vector now results in the following output:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/datarwe/pytorch/torch/utils/data/sampler.py in <module>
    274 WeightedRandomSampler([1,2,3], 10)
--> 275 WeightedRandomSampler([[1,2,3], [4,5,6]], 10)

~/datarwe/pytorch/torch/utils/data/sampler.py in __init__(self, weights, num_samples, replacement, generator)
    192         weights = torch.as_tensor(weights, dtype=torch.double)
    193         if len(weights.shape) != 1:
--> 194             raise ValueError("weights should be a 1d sequence but given "
    195                              "weights have shape {}".format(tuple(weights.shape)))
    196

ValueError: weights should be a 1d sequence but given weights have shape (2, 3)
```
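A quick before/after sketch of the contract this enforces (seed and weights are arbitrary choices for the example):

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Correct usage: a 1-d sequence, one weight per candidate index.
weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]
sampler = WeightedRandomSampler(weights, num_samples=5, replacement=True,
                                generator=torch.Generator().manual_seed(0))
draws = list(sampler)
print(len(draws), all(0 <= i < len(weights) for i in draws))  # 5 True

# A 2-d weights sequence now fails fast instead of misbehaving later:
try:
    WeightedRandomSampler([[1, 2, 3], [4, 5, 6]], num_samples=10)
except ValueError as e:
    print(type(e).__name__)  # ValueError
```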
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78585
Approved by: https://github.com/NivekT, https://github.com/ejguan
2022-06-02 21:12:14 +00:00
木澄
d22d749a0e faster batch sampler (#76951)
Fixes #76950

Improve the performance of iteration on `BatchSampler`, especially when `batch_size` is large.

Python 3.6.8:
```
  batch_size  drop_last     speedup
------------  -----------   -------
           4  True          -18.07%
           4  False         15.92%
           8  True          9.43%
           8  False         30.90%
          64  True          54.99%
          64  False         49.64%
         640  True          66.26%
         640  False         48.32%
        6400  True          69.06%
        6400  False         45.17%
```

Python 3.8.12:
```
  batch_size  drop_last    speedup
------------  -----------  --------
           4  True         -10.50%
           4  False        -0.78%
           8  True         24.40%
           8  False        10.20%
          64  True         90.96%
          64  False        26.09%
         640  True         112.88%
         640  False        20.09%
        6400  True         111.80%
        6400  False        18.37%

```

Check the issue page for more details of the tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76951
Approved by: https://github.com/ejguan
2022-05-10 18:19:54 +00:00
Erjia Guan
0289ab2cec Fix data-related public API (#368)
Summary:
X-link: https://github.com/pytorch/data/pull/368

This PR aims to expose the right data-related API.

There are two more changes made in this PR to convert public APIs to private:
`check_lambda_fn` -> `_check_lambda_fn`
`deprecation_warning` -> `_deprecation_warning`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76143

Reviewed By: albanD, NivekT

Differential Revision: D35798311

Pulled By: ejguan

fbshipit-source-id: b13fded5c88a533c706702fb2070c918c839dca4
(cherry picked from commit 0b534b829a2e90e1e533951c6d334fdeaa9358b9)
2022-04-21 17:27:05 -07:00
pyhuang97@gmail.com
16a9ffba4b Allow specifying num_samples to RandomSampler even when replacement=False (#71568)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/38032 and #39214

Hi, I modified the RandomSampler to satisfy the requirement of https://github.com/pytorch/pytorch/issues/38032. I also added and deleted some test cases in test/test_dataloader.py to match the new requirement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71568

Reviewed By: mikaylagawarecki

Differential Revision: D33741776

Pulled By: ejguan

fbshipit-source-id: 2d25f5096b7b36ad9fb6455107182f387cf8ee43
(cherry picked from commit 9c7e1891c2)
2022-01-25 15:34:24 +00:00
Erjia Guan
060e41eafa Forward fix type hint for DataLoader (#66001)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66001

Test Plan: Imported from OSS

Reviewed By: NivekT

Differential Revision: D31340565

Pulled By: ejguan

fbshipit-source-id: d05ae42ebf93f61d781dc5d81ef0222e24f5acb3
2021-10-01 15:48:45 -07:00
Erjia Guan
b777d790ea Convert Sampler back to lazily construction (#63646)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63646

Fixes #63609

Test Plan: Imported from OSS

Reviewed By: NivekT

Differential Revision: D30451774

Pulled By: ejguan

fbshipit-source-id: 550d77494326446d1a42b5da0559e0d384c47413
2021-09-30 07:32:06 -07:00
MY_
dc5ce22a1a A re-open PR: Avoid re-creating the random number generator in RandomSampler (#63026)
Summary:
More details can be found in the old pr: https://github.com/pytorch/pytorch/pull/53085

ejguan  Thanks for your guidance. I tried to reopen this PR following your instructions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63026

Reviewed By: anjali411

Differential Revision: D30224920

Pulled By: ejguan

fbshipit-source-id: 2fa83bd4a2661485e553447fe3e57ce723f2716d
2021-08-16 14:08:37 -07:00
Sam Estep
75024e228c Add lint for unqualified type: ignore (#56290)
Summary:
The other half of https://github.com/pytorch/pytorch/issues/56272.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56290

Test Plan:
CI should pass on the tip of this PR, and we know that the lint works because the following CI runs (before this PR was finished) failed:

- https://github.com/pytorch/pytorch/runs/2384511062
- https://github.com/pytorch/pytorch/actions/runs/765036024

Reviewed By: seemethere

Differential Revision: D27867219

Pulled By: samestep

fbshipit-source-id: e648f07b6822867e70833e23ddafe7fb7eaca235
2021-04-21 08:07:23 -07:00
Zhiyuan Chen
7d4e9bdba1 Add type hint for SequentialSampler (#56374)
Summary:
Add type hint for SequentialSampler

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56374

Reviewed By: heitorschueroff

Differential Revision: D27884528

Pulled By: ejguan

fbshipit-source-id: 68eb900643098565743245c843e76e464f981458
2021-04-20 14:45:52 -07:00
MY_
b22b082cc8 Fixed the error of generator in the RandomSampler. (#52956)
Summary:
In `__iter__` of the `RandomSampler`, when `self.replacement` is `False`, the original code always used `self.generator` in the `torch.randperm` call instead of the generator we set.

Fixes https://github.com/pytorch/pytorch/issues/52568
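With the fix, the generator you pass actually drives the permutation, so two samplers seeded identically agree even with `replacement=False` (a small reproducibility sketch, seed chosen arbitrarily):

```python
import torch
from torch.utils.data import RandomSampler

data = range(10)  # any Sized object works as the data source
a = list(RandomSampler(data, generator=torch.Generator().manual_seed(7)))
b = list(RandomSampler(data, generator=torch.Generator().manual_seed(7)))
print(a == b)  # True: same seed, same permutation
```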

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52956

Reviewed By: mruberry

Differential Revision: D26724303

Pulled By: H-Huang

fbshipit-source-id: 86f2795c76f3548e31181fb077af046078a173cb
2021-03-01 10:05:43 -08:00
Chester Liu
58eb23378f Clean up usage of torch._six partially (#49785)
Summary:
See https://github.com/pytorch/pytorch/issues/42919

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49785

Reviewed By: mruberry

Differential Revision: D25963833

Pulled By: bugra

fbshipit-source-id: 11c90d6b8d3f206c9d0a4d8621b773beb10c6ba2
2021-02-08 13:58:34 -08:00
Samuel Marks
e6779d4357 [*.py] Rename "Arguments:" to "Args:" (#49736)
Summary:
I've written custom parsers and emitters for everything from docstrings to classes and functions. However, I recently came across an issue when I was parsing/generating from the TensorFlow codebase: inconsistent use of `Args:` and `Arguments:` in its docstrings.

```sh
(pytorch#c348fae)$ for name in 'Args:' 'Arguments:'; do
    printf '%-10s %04d\n' "$name" "$(rg -IFtpy --count-matches "$name" | paste -s -d+ -- | bc)"; done
Args:      1095
Arguments: 0336
```

It is easy enough to extend my parsers to support both variants, however it looks like `Arguments:` is wrong anyway, as per:

  - https://google.github.io/styleguide/pyguide.html#doc-function-args @ [`ddccc0f`](https://github.com/google/styleguide/blob/ddccc0f/pyguide.md)

  - https://chromium.googlesource.com/chromiumos/docs/+/master/styleguide/python.md#describing-arguments-in-docstrings @ [`9fc0fc0`](https://chromium.googlesource.com/chromiumos/docs/+/9fc0fc0/styleguide/python.md)

  - https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html @ [`c0ae8e3`](https://github.com/sphinx-contrib/napoleon/blob/c0ae8e3/docs/source/example_google.rst)

Therefore, only `Args:` is valid. This PR replaces them throughout the codebase.

PS: For related PRs, see tensorflow/tensorflow/pull/45420

PPS: The trackbacks automatically appearing below are sending the same changes to other repositories in the [PyTorch](https://github.com/pytorch) organisation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49736

Reviewed By: albanD

Differential Revision: D25710534

Pulled By: soumith

fbshipit-source-id: 61e8ff01abb433e9f78185c2d1d0cbd7c22c1619
2020-12-28 09:34:47 -08:00
Natalia Gimelshein
74c3dcd1d2 Revert D23725053: [pytorch][PR] change self.generator to generator
Test Plan: revert-hammer

Differential Revision:
D23725053 (a011b86115)

Original commit changeset: 89706313013d

fbshipit-source-id: 035214f0d4298d29a52f8032d364b52dfd956fe8
2020-09-17 09:42:37 -07:00
Fang Zhang
a011b86115 change self.generator to generator (#44461)
Summary:
bug fix

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44461

Reviewed By: mruberry

Differential Revision: D23725053

Pulled By: ngimel

fbshipit-source-id: 89706313013d9eae96aaaf144924867457efd2c0
2020-09-16 11:32:17 -07:00
Daiming Yang
ad7133d3c1 Patch for #40026 RandomSampler generates samples one at a time when replacement=True (#41682)
Summary:
Fix https://github.com/pytorch/pytorch/issues/32530
Fix/Patch https://github.com/pytorch/pytorch/pull/40026

Resubmit this patch and fix the type error.

Force the input type to `manual_seed()` in `sampler.py` to be `int`.

ezyang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41682

Reviewed By: izdeby

Differential Revision: D22665477

Pulled By: ezyang

fbshipit-source-id: 1725c8aa742c31e74321f20448f4b6a392afb38d
2020-07-22 13:45:09 -07:00
Shen Li
86590f226e Revert D22519869: [pytorch][PR] RandomSampler generates samples one at a time when replacement=True
Test Plan: revert-hammer

Differential Revision:
D22519869 (09647e1287)

Original commit changeset: be6585002586

fbshipit-source-id: 31ca5ceb24dd0b291f46f427a6f30f1037252a5d
2020-07-16 12:59:10 -07:00
Daiming Yang
09647e1287 RandomSampler generates samples one at a time when replacement=True (#40026)
Summary:
Fix https://github.com/pytorch/pytorch/issues/32530

I used the next() function to generate samples one at a time. To handle the replacement=False case, I added a variable called "sample_list" to RandomSampler for random permutation.
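The replacement=True idea can be sketched as a generator that draws one index per request rather than materializing all draws up front (a simplified illustration, not the PR's exact code):

```python
import torch

def sample_with_replacement(n: int, num_samples: int, generator=None):
    """Hypothetical sketch: draw indices lazily, one at a time."""
    for _ in range(num_samples):
        yield int(torch.randint(0, n, (1,), generator=generator))

g = torch.Generator().manual_seed(0)
draws = list(sample_with_replacement(5, 8, generator=g))
print(len(draws), all(0 <= i < 5 for i in draws))  # 8 True
```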

cc SsnL

Pull Request resolved: https://github.com/pytorch/pytorch/pull/40026

Reviewed By: zhangguanheng66

Differential Revision: D22519869

Pulled By: ezyang

fbshipit-source-id: be65850025864d659a713b3bc461b25d6d0048a2
2020-07-16 11:42:32 -07:00
Wojciech Baranowski
0e09511af9 type annotations for dataloader, dataset, sampler (#39392)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/38913

Pull Request resolved: https://github.com/pytorch/pytorch/pull/39392

Reviewed By: anjali411

Differential Revision: D22102489

Pulled By: zou3519

fbshipit-source-id: acb68d9521145f0b047214d62b5bdc5a0d1b9be4
2020-07-07 07:16:18 -07:00
ShawnZhong
c8c53c802e Add generator= kwarg for DataLoader & random samplers (#39737)
Summary:
Fix https://github.com/pytorch/pytorch/issues/39572

Add `generator=` kwarg for DataLoader & random samplers
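Typical use of the new kwarg is reproducible shuffling: seed a `torch.Generator` and hand it to the loader (dataset and seed below are arbitrary example choices):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(8))

def epoch(seed: int):
    g = torch.Generator().manual_seed(seed)
    loader = DataLoader(ds, batch_size=4, shuffle=True, generator=g)
    return [batch[0].tolist() for batch in loader]

print(epoch(0) == epoch(0))  # True: same seed, same shuffle order
```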

cc: SsnL, deeppatel4557, albanD, mitar
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39737

Differential Revision: D22019132

Pulled By: albanD

fbshipit-source-id: 835e08b86c5396bc0b0e41057661306b15394d6e
2020-06-15 07:01:20 -07:00
Hong Xu
283a3ff16d The exception raised when RandomSampler.replacement is non-boolean should be TypeError (#36547)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36547

Differential Revision: D21818752

Pulled By: ezyang

fbshipit-source-id: 7502a24a0df134c44ac72959ba992777c873f8e9
2020-06-02 06:54:02 -07:00
SsnL
b5868b2833 Relax sampler check in BatchSampler (#38403)
Summary:
Since the check was added in https://github.com/pytorch/pytorch/pull/6249, one cannot pass an iterable as a sampler to the data loader anymore, which was a very handy feature (e.g., https://github.com/pytorch/pytorch/issues/1337). I think the check should be removed for two reasons:
1. It is too strict. There is no reason that it should not be a general iterable.
2. It is inconsistent. In `DataLoader` (the main place where people use samplers), you can pass a general iterable as `batch_sampler` but not `sampler` due to this check.
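With the check relaxed, any iterable of indices works as `sampler=`, not only `torch.utils.data.Sampler` instances (a small sketch; the dataset and index order are arbitrary):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))
indices = list(reversed(range(10)))  # a plain list, not a Sampler subclass
loader = DataLoader(ds, sampler=indices, batch_size=5)
print([batch[0].tolist() for batch in loader])  # [[9, 8, 7, 6, 5], [4, 3, 2, 1, 0]]
```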
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38403

Differential Revision: D21555958

Pulled By: soumith

fbshipit-source-id: c7267bb99a31edd8f2750689205d6edc5dab5cff
2020-05-13 22:24:29 -07:00
vfdev
c6e0360812 Minor change of docstring example of WeightedRandomSampler (#30846)
Summary:
Previous example
```python
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [0, 0, 0, 1, 0]
```
may seem misleading according to provided weights.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30846

Differential Revision: D19697367

Pulled By: ezyang

fbshipit-source-id: 3d6e3cd0cecb5272a368707ba35bc7acdbd82c30
2020-02-12 07:46:39 -08:00
Tongzhou Wang
058beae411 Add IterableDataset (#19228)
Summary:
This is a modified version of https://github.com/pytorch/pytorch/pull/14705 since commit structure for that PR is quite messy.

1. Add `IterableDataset`.
2. So we have 2 data loading modes: `Iterable` and `Map`.

    1. `Iterable` if the `dataset` is an instance of `IterableDataset`
    2. `Map` otherwise.

3. Add better support for non-batch loading (i.e., `batch_size=None` and `batch_sampler=None`). This is useful in doing things like bulk loading.
4. Refactor `DataLoaderIter` into two classes, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`. Rename some methods to be more generic, e.g., `get_batch` -> `get_data`.
5. Add `torch.utils.data.get_worker_info` which returns worker information in a worker proc (e.g., worker id, dataset obj copy, etc.) and can be used in `IterableDataset.__iter__` and `worker_init_fn` to do per-worker configuration.
6. Add `ChainDataset`, which is the analog of `ConcatDataset` for `IterableDataset`.
7. Import torch.utils.data in `torch/__init__.py`
8. Data loader examples and documentation.
9. Use `get_worker_info` to detect whether we are in a worker process in `default_collate`.
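The `IterableDataset` + `get_worker_info` pattern described above can be sketched as a stream that shards itself per worker (the class is an illustrative example, not from the PR):

```python
from torch.utils.data import IterableDataset, get_worker_info

class CountStream(IterableDataset):
    """Hypothetical stream: each worker yields its own shard of [0, n)."""
    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self):
        info = get_worker_info()
        if info is None:  # single-process loading: yield everything
            return iter(range(self.n))
        # Stride the range so workers don't yield duplicate items.
        return iter(range(info.id, self.n, info.num_workers))

print(list(CountStream(5)))  # [0, 1, 2, 3, 4] in the main process
```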

Closes https://github.com/pytorch/pytorch/issues/17909, https://github.com/pytorch/pytorch/issues/18096, https://github.com/pytorch/pytorch/issues/19946, and some of https://github.com/pytorch/pytorch/issues/13023
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19228

Reviewed By: bddppq

Differential Revision: D15058152

fbshipit-source-id: 9e081a901a071d7e4502b88054a34b450ab5ddde
2019-06-20 20:12:44 -07:00
Soumith Chintala
6480d3f140 Revert D15511921: [pytorch][PR] BatchSampler now uses list.clear() instead of creating new objects
Differential Revision:
D15511921

Original commit changeset: e943d21e75e1

fbshipit-source-id: 933b7ef74c7a530f0a2cc087c8ee6f0455cf9239
2019-05-27 10:51:24 -07:00
Tongzhou Wang
482ae8e6b2 BatchSampler now uses list.clear() instead of creating new objects
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20976

Differential Revision: D15511921

Pulled By: soumith

fbshipit-source-id: e943d21e75e19f9154a0570f3188cc3ce174083e
2019-05-26 23:45:26 -07:00