Commit Graph

579 Commits

Author SHA1 Message Date
Ramil Nugmanov
91eeb77260 StackDataset batched sampling (#110694)
Optimization of loading minibatches

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110694
Approved by: https://github.com/ejguan
2023-10-10 22:05:51 +00:00
Joel Schlosser
43ea782af3 Multiprocessing support for NT (#110292)
Fixes #110161

Allows NTs to be used in DataLoaders with `num_workers > 1`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110292
Approved by: https://github.com/cpuhrsch, https://github.com/albanD
2023-10-10 21:58:19 +00:00
PyTorch MergeBot
dac895c10a Revert "Multiprocessing support for NT (#110292)"
This reverts commit f17fe89e14.

Reverted https://github.com/pytorch/pytorch/pull/110292 on behalf of https://github.com/kit1980 due to Causes CUDA memory leaks ([comment](https://github.com/pytorch/pytorch/pull/110292#issuecomment-1749852095))
2023-10-06 01:07:40 +00:00
Joel Schlosser
f17fe89e14 Multiprocessing support for NT (#110292)
Fixes #110161

Allows NTs to be used in DataLoaders with `num_workers > 1`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110292
Approved by: https://github.com/cpuhrsch, https://github.com/albanD
2023-10-05 15:04:48 +00:00
Navid Sheik
96a3a7cc82 [pytorch] make IterableDataset of Iterable type (#109645)
Summary: Makes `IterableDataset` of `Iterable` type.

Test Plan: tests next diff in the stack are all green

Differential Revision: D49420146

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109645
Approved by: https://github.com/DanilBaibak, https://github.com/Skylion007
2023-09-25 14:18:15 +00:00
katotaisei
bcda859e34 fix typos (#108006)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108006
Approved by: https://github.com/Skylion007
2023-08-28 19:49:09 +00:00
PyTorch MergeBot
ecde622649 Revert "reseed all Generators in Dataloader's _worker_loop() -- via GC (#107131)"
This reverts commit 42625da5e1.

Reverted https://github.com/pytorch/pytorch/pull/107131 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/107131#issuecomment-1690325745))
2023-08-23 17:08:07 +00:00
Aaron Gokaslan
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please hep them get unblocked? It seems like one of the strings was prob accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
Nicolas Hug
42625da5e1 reseed all Generators in Dataloader's _worker_loop() -- via GC (#107131)
Alternative to https://github.com/pytorch/pytorch/pull/107034, implements @ezyang 's suggestion from https://github.com/pytorch/pytorch/pull/107034#discussion_r1292857201.

This PR addresses https://fb.workplace.com/groups/pytorch.oss.dev/posts/1699944830430051 and does a bunch of stacked changes:

- Make `Generator` class support GC;this makes all `Generator` instances tracked and accessile through Python's GC.
- Use the GC to retrieve all existing Generator instances in Dataloader's `_worker_loop` and re-seed them: this extends what is already applied to the global/default Generator, which is already re-seeded.

~TODO: a bit of docs and justification, which I'll do if this PR is mergeable.~ -- Done

CC @albanD @ezyang  as previously discussed

BC-Breaking Note
-------------------

We now re-seed all `Generator` instances within the `Dataloader` workers' loop to ensure that their RNG is different across workers.
Previously, the RNG of user-defined `Generators` would be the same across workers, which could lead to wrong training procedures. This only affects user-defined `Generators`, not the default `Generator` (which was already re-seeded).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107131
Approved by: https://github.com/ezyang
2023-08-18 10:23:23 +00:00
Aaron Gokaslan
6d43c89f37 [BE]: Update Ruff to 0.0.280 (#105724)
Removes unusued loop values in python dictionary iteration. Automated fix from Ruff master

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724
Approved by: https://github.com/ezyang, https://github.com/janeyx99
2023-07-22 23:03:34 +00:00
Justin Chu
4cc1745b13 [BE] f-stringify torch/ and scripts (#105538)
This PR is a follow up on the pyupgrade series to convert more strings to use f-strings using `flynt`.

- https://docs.python.org/3/reference/lexical_analysis.html#f-strings
- https://pypi.org/project/flynt/

Command used:

```
flynt torch/ -ll 120
flynt scripts/ -ll 120
flynt tools/ -ll 120
```

and excluded `collect_env.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105538
Approved by: https://github.com/ezyang, https://github.com/malfet
2023-07-21 19:35:24 +00:00
Justin Chu
abc1cadddb [BE] Enable ruff's UP rules and autoformat utils/ (#105424)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105424
Approved by: https://github.com/ezyang, https://github.com/malfet
2023-07-18 20:17:25 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to It's dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  -
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
Aaron Gokaslan
2f95a3d0fc [BE]: Apply ruff PERF fixes to torch (#104917)
Applies automated ruff fixes in the PERF modules and enables all automatic ones. I also updated ruff which applied some additional fixes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104917
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-07-11 20:45:21 +00:00
Liang
def1b57151 Update datapipe.py (#103834)
change 'dp' to 'source_dp'

Fixes #103833

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103834
Approved by: https://github.com/kit1980
2023-06-19 18:05:56 +00:00
Avi Verma
d80174e2db Do not materialize entire randperm in RandomSampler (#103339)
In our DDP training workloads, each rank was initializing a `RandomSampler` for a dataset with a length of 3.5 billion items. We noticed that when this sampler was in scope, `gc.collect` calls were taking on the order of seconds to run, which would slow down the entire training iteration. This is because when we call `torch.randperm(n).tolist()`, we create a python list of 3.5 billion items, which massively slows down the periodic mark & sweep garbage collection.

This PR swaps out the `.tolist()` call with a `.numpy()` call and manually calls `.item()` on each element as it is being requested. This has two benefits:

1. The first call to `RandomSampler::__next__` should be about twice as fast, since `.numpy` does not copy the contents of the original tensor
2. The runtime of `gc.collect()` calls no longer scales linearly with the size of the dataset passed to `RandomSampler`

I've attached some `timeit` samples to illustrate the speedups with this Pr:

```
Main (no GC):  51.72115747816861
Main (10 GC calls) 83.61965207383037
PR (no GC) 33.06403830461204
PR (10 GC calls) 33.959467427805066
```

Code
```python
from timeit import timeit

baseline_no_gc = """
import torch

n = int(1e9)
steps = n // 100

x = torch.randperm(n).tolist()
x_iter = iter(x)

for i in range(steps):
    next(x_iter)
"""

baseline_gc = """
import torch
import gc
n = int(1e9)
steps = n // 100
gc_every = steps // 10

x = torch.randperm(n).tolist()
x_iter = iter(x)

for i in range(steps):
    next(x_iter)
    if i % gc_every == 0:
        gc.collect()
"""

numpy_no_gc = """
import torch
n = int(1e9)
steps = n // 100

x = torch.randperm(n).numpy()
x_iter = (i.item() for i in x)

for i in range(steps):
    next(x_iter)
"""

numpy_gc = """
import torch
import gc
n = int(1e9)
steps = n // 100
gc_every = steps // 10

x = torch.randperm(n).numpy()
x_iter = (i.item() for i in x)

for i in range(steps):
    next(x_iter)
    if i % gc_every == 0:
        gc.collect()
"""

if __name__ == "__main__":
    print("Main (no GC): ", timeit(baseline_no_gc, number=1))
    print("Main (10 GC calls)", timeit(baseline_gc, number=1))
    print("PR (no GC)",  timeit(numpy_no_gc, number=1))
    print("PR (10 GC calls)", timeit(numpy_gc, number=1))

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103339
Approved by: https://github.com/kit1980
2023-06-16 19:25:58 +00:00
Wei Ji
f95d42b1b7 [DataPipe] Update docstring for functional form of DataPipes (#100446)
Copy the docstring from IterDataPipe and MapDataPipe classes to their functional form. Done using [`functools.update_wrapper`](https://docs.python.org/3/library/functools.html#functools.update_wrapper), xref https://stackoverflow.com/questions/6394511/python-functools-wraps-equivalent-for-classes.

See also parallel change to `.pyi` stub files at https://github.com/pytorch/pytorch/pull/100503

Fixes https://github.com/pytorch/data/issues/792 and https://github.com/weiji14/zen3geo/issues/69.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100446
Approved by: https://github.com/NivekT
2023-05-18 19:59:00 +00:00
Ramil Nugmanov
28098cae6b [DataLoader] Adding StackDataset (#101338)
Torch wrapping datasets list has:
`TensorDataset`
`ConcatDataset`
`ChainDataset`

`TensorDataset` is useful for stacking sets of tensors but can't work with objects without `.size()` method.

This PR proposes `StackDataset`, similar to `TensorDataset` but for a general case like `ConcatDataset`.

Possible usage of `StackDataset` is multimodal networks with different input like image+text or for staking non-tensor input and property to predict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101338
Approved by: https://github.com/ejguan, https://github.com/NivekT
2023-05-18 00:57:12 +00:00
Kevin Tse
8a193c6dc5 [DataPipe] Add generated docstring to functional form DataPipe (#100503)
This PR modified the generation process of `datapipe.pyi` to include the doc strings for each DataPipe in functional form.

The new generated `.pyi` file will look like [this](https://gist.github.com/NivekT/95095f14da85a837a0727a19a5ba367c). I have confirmed the doc string will be visible in PyCharm.

You can copy this [file](https://gist.github.com/NivekT/95095f14da85a837a0727a19a5ba367c) and overwrite your local `datapipe.pyi` to validate this change as well.

Note: We need to create a similar change in TorchData to allow DataPipes in that library to show the doc strings as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100503
Approved by: https://github.com/ejguan
2023-05-10 14:06:46 +00:00
Ramil Nugmanov
a2e81a8004 [DataLoader] __getitems__ added to description of Dataset API and better supported within Subset (#100375)
DataLoader supports batched loading from Mapped Datasets.

This is the fetcher's implementation of auto-detection of batch loading support.

torch.utils.data._utils.fetch._MapDatasetFetcher
```
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                data = self.dataset.__getitems__(possibly_batched_index)
            else:
                data = [self.dataset[idx] for idx in possibly_batched_index]
```

Description of Dataset API now shows this feature.

Additionally, Subset dataset now supports `__getitems__` if parent dataset supports it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100375
Approved by: https://github.com/ejguan, https://github.com/NivekT
2023-05-05 15:52:28 +00:00
Chase
2f41bc5465 [DataLoader] Add context to NotImplementedErrors in dataset.py (#100667)
Add helpful context message to `NotImplementedError`'s thrown by Dataset and IterableDataset, reminding users that they must implement `__getitem__`/`__iter__` in subclasses. Currently, users are presented with a bare `NotImplementedError` without describing the remedy.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100667
Approved by: https://github.com/NivekT
2023-05-05 02:16:42 +00:00
Kevin Tse
f04bb519f5 [DataPipe] Change DataPipe display name in profiler (#100042)
Script:
```python
from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

ls = range(16)
dp = IterableWrapper(ls).map(fn_2).map(fn_3).map(fn_4)

rs = MultiProcessingReadingService(num_workers=0, main_prefetch_cnt=0, worker_prefetch_cnt=0)
dl2 = DataLoader2(dp, reading_service=rs)

with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU]) as prof:
    for x in dl2:
        pass
```

Output before:
```
---------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                                               Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
---------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
             enumerate(DataPipe)#MapperIterDataPipe        76.37%       1.419ms       213.08%       3.959ms      80.796us            49
    enumerate(DataPipe)#IterableWrapperIterDataPipe        12.70%     236.000us        12.70%     236.000us      13.882us            17
...
```

Output after:
```
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Mapper(datapipe=Mapper, fn=fn_4, input_col=None, out...        29.79%     645.000us        99.17%       2.147ms     126.294us            17
Mapper(datapipe=IterableWrapper, fn=fn_2, input_col=...        29.24%     633.000us        42.96%     930.000us      54.706us            17
Mapper(datapipe=Mapper, fn=fn_3, input_col=None, out...        24.76%     536.000us        68.59%       1.485ms      87.353us            17
IterableWrapper(deepcopy=True, iterable=range(0, 16)...        10.58%     229.000us        10.58%     229.000us      13.471us            17
...
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100042
Approved by: https://github.com/ejguan
2023-05-03 21:36:13 +00:00
Ramil Nugmanov
3e18d3958b [DataLoader] Follow-up Fix: TypeVars of Sampler (#100409)
API backward compatibility fixed:
https://github.com/pytorch/pytorch/pull/97338#discussion_r1169164163

Mapped Dataset can accept noninteger indices from custom Samplers.

Fixes #97338

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100409
Approved by: https://github.com/ejguan, https://github.com/NivekT
2023-05-03 17:38:31 +00:00
Kevin Tse
3d8498f926 [DataLoader] Add missing documentation for arg in DataLoader (#99371)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99371
Approved by: https://github.com/janeyx99
2023-04-18 02:03:47 +00:00
erjia
29d2e4b7fa Forward fix for DataLoader to accept custom Sharding DataPipe (#97287)
Fixes #96975

Changes:
- Make sure custom ShardingDataPipe with `apply_sharding` can be used by `DataLoader`
  - Allow the `apply_sharding` function without the last argument of `sharding_group`
- Make `DataLoader` not relying on `sharding_group`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97287
Approved by: https://github.com/NivekT
2023-04-05 22:33:37 +00:00
Xuehai Pan
e6888697c4 Revisit torch._six.string_classes removal (#94709) (#97863)
Revisit `torch._six.string_classes` (which is `(str, bytes)`) removal: `isinstance(obj, string_classes) -> isinstance(obj, str)`.

Both `str` and `bytes` are `Sequence` classes.

```python
In [1]: from typing import Sequence

In [2]: issubclass(bytes, Sequence)
Out[2]: True

In [3]: issubclass(str, Sequence)
Out[3]: True
```

Re-add `bytes` to type guards like:

```python
def is_seq(obj):
    return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes))
```

Ref:

- https://github.com/pytorch/pytorch/pull/94709#issuecomment-1487282912
- #97737
- #97789
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97863
Approved by: https://github.com/Skylion007, https://github.com/albanD
2023-03-30 17:02:45 +00:00
Donny You
3460b2b7d3 Add support for pin memory on custom device. (#97621)
Add support for pin memory on custom device.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97621
Approved by: https://github.com/NivekT
2023-03-29 23:45:52 +00:00
Sergii Dymchenko
46faa79e09 Simplify by using yield from in torch/utils/data (#97839)
Also see https://github.com/pytorch/pytorch/pull/97831
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97839
Approved by: https://github.com/NivekT, https://github.com/Skylion007
2023-03-29 04:51:26 +00:00
Kevin Tse
bb42104fe8 [DataLoader] Fix collation logic (#97789)
Similar to #97737, a previous auto-refactor changed how `bytes` are handled during collation, which can potentially lead to performance regression. This PR undoes that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97789
Approved by: https://github.com/albanD
2023-03-28 20:25:34 +00:00
Eric Zhang
d8cc8ffebc [DataLoader] Short circuit pin_memory recursion when operating on bytes (#97737)
Slack thread: https://pytorch.slack.com/archives/GEEQ2K4MD/p1679962409906099

I was seeing some massive (~2x) slowdowns on a job after running it on PyTorch 2.0. From some profiling in `py-spy` it looked like the pin_memory thread was doing a lot more work than before. Looking at a trace in `nsys` I saw the thread doing the forward pass having a bunch of `pthread_cond_timedwait` with GIL reacquire calls in it’s call stack, and it seemed like the thread doing the forward pass was getting blocked (waiting for the GIL) by the pin memory thread (which was holding the GIL).

After some debugging I found out the issue. If a `bytes` was passed into `pin_memory`, previously in 1.13 (before https://github.com/pytorch/pytorch/pull/94709) it would short-circuit and return here
d922c29a22/torch/utils/data/_utils/pin_memory.py (L54-L55)
since `bytes` was in `torch._six.string_classes`:
```
>>> from torch._six import string_classes
>>> string_classes
(<class 'str'>, <class 'bytes'>)
>>>
```

However after https://github.com/pytorch/pytorch/pull/94709, if a `bytes` was passed into `pin_memory` it would fall into here instead
c263bd43e8/torch/utils/data/_utils/pin_memory.py (L68-L73)
because the previous check is now doing `isinstance(data, str)` instead of `isinstance(data, (str, bytes))`!
c263bd43e8/torch/utils/data/_utils/pin_memory.py (L56-L57)

As a result, `pin_memory` gets called recursively for each element in the `bytes` leading to a ton of wasted recursion. This also explains the slowdown / GIL contention I was seeing.

This PR simply changes `isinstance(data, str)` to `isinstance(data, (str, bytes))` to match the behavior before https://github.com/pytorch/pytorch/pull/94709

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97737
Approved by: https://github.com/albanD, https://github.com/NivekT
2023-03-28 17:39:23 +00:00
Ramil Nugmanov
867b07b424 Sampler API described for customization. (#97338)
Explanation with examples of sampler customization added.

* fixed TypeVar
* removed unused init from Sampler class
* added examples for custom sampler and batch sampler
* Distributed sampler typing fixed.
* _InfiniteConstantSampler fixed

Fixes #92268

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97338
Approved by: https://github.com/NivekT
2023-03-28 06:40:38 +00:00
Kevin Tse
c5135ff2a6 [DataPipe] Fix missing imports in DataPipe interface file (#97458)
Fixes https://github.com/pytorch/data/issues/1106

Ran linter locally on `datapipes.pyi` (which is generated during installation) to confirm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97458
Approved by: https://github.com/mikaylagawarecki
2023-03-24 19:25:43 +00:00
Kazuaki Ishizaki
622a11d512 Fix typos under torch/utils directory (#97516)
This PR fixes typos in comments and messages of `.py` files under `torch/utils` directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97516
Approved by: https://github.com/ezyang
2023-03-24 16:53:39 +00:00
Tillmann Falck
939c4ae6cd [DataPipe] Add copy option to fork DataPipe (#96030)
Fixes pytorch/data#1061 and fixes pytorch/data#1032
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96030
Approved by: https://github.com/ejguan, https://github.com/NivekT
2023-03-10 17:31:56 +00:00
erjia
738cc5e644 Fix validate_input_col for nn.Module or Callable (#96213)
Forward fix the problem introduced in https://github.com/pytorch/pytorch/pull/95067

Not all `Callable` objects have `__name__` implemented. Using `repr` as the backup solution to get function name or reference.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96213
Approved by: https://github.com/NivekT
2023-03-08 01:30:17 +00:00
Felix
e6f3e16d89 Fix: validate_input_col for partial functions (#95067)
Fixes #95066

#### Proposed change:
do not call `str()` on a `Callable` to determine its name

#### Reasoning:
Please see https://github.com/pytorch/pytorch/issues/95066 for reasoning and examples

#### Effect:
* The code example given in https://github.com/pytorch/pytorch/issues/95066 now executes instantly.
* If invalid input is provided, the stacktrace now prints nicely as
  ```
  ValueError: The function foo takes 1 parameters, but 2 are required.
  ```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95067
Approved by: https://github.com/NivekT, https://github.com/ejguan
2023-03-03 21:05:07 +00:00
Xuehai Pan
1fd119948e [3/3] Update .pyi Python stub files and enable 'UFMT' linter (#95268)
Changes:

- #95200

1. Recognize `.py.in` and `.pyi.in` files as Python in VS Code for a better development experience.
2. Fix deep setting merge in `tools/vscode_settings.py`.

- #95267

3. Use `Namedtuple` rather than `namedtuple + __annotations__` for `torch.nn.utils.rnn.PackedSequence_`:

    `namedtuple + __annotations__`:

    ```python
    PackedSequence_ = namedtuple('PackedSequence_',
                                 ['data', 'batch_sizes', 'sorted_indices', 'unsorted_indices'])

    # type annotation for PackedSequence_ to make it compatible with TorchScript
    PackedSequence_.__annotations__ = {'data': torch.Tensor, 'batch_sizes': torch.Tensor,
                                       'sorted_indices': Optional[torch.Tensor],
                                       'unsorted_indices': Optional[torch.Tensor]}
    ```

    `Namedtuple`: Python 3.6+

    ```python
    class PackedSequence_(NamedTuple):
        data: torch.Tensor
        batch_sizes: torch.Tensor
        sorted_indices: Optional[torch.Tensor]
        unsorted_indices: Optional[torch.Tensor]
    ```

- => this PR: #95268

4. Sort import statements and remove unnecessary imports in `.pyi`, `.pyi.in` files.
5. Format `.pyi`, `.pyi.in` files and remove unnecessary ellipsis `...` in type stubs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95268
Approved by: https://github.com/huydhn
2023-03-01 23:50:56 +00:00
Kilian Lieret
66bea59538 Clarify meaning of pin_memory_device argument (#94349)
I don't think the docstring explaining `pin_memory_device` is very clear. If it weren't for the string type, I would not have guessed that this was about the device that is referred to in the `pin_memory` option (and honestly, it took me a few minutes before noticing the type).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94349
Approved by: https://github.com/ejguan
2023-02-15 20:40:28 +00:00
Xuehai Pan
b005ec62b9 [BE] Remove dependency on six and future (#94709)
Remove the Python 2 and 3 compatibility library [six](https://pypi.org/project/six) and [future](https://pypi.org/project/future) and `torch._six`. We only support Python 3.8+ now. It's time to retire them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94709
Approved by: https://github.com/malfet, https://github.com/Skylion007
2023-02-14 09:14:14 +00:00
Aaron Gokaslan
67d9790985 [BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
2023-02-12 01:01:25 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
Wenlei Xie
d8f4026ebf Continue support sharding pipes in tud.datapipes.iter.grouping as deprecated (#94527)
Summary:
https://github.com/pytorch/pytorch/pull/94095 moves this into `tud.datapipes.iter.sharding`. However, since previously this is a public API, this is a BC break change.

As discussed in https://github.com/pytorch/data/pull/987#issuecomment-1422440049, we will have backward compatbile support but with deprecated warning.

Differential Revision: D43161015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94527
Approved by: https://github.com/ejguan, https://github.com/NivekT
2023-02-10 18:42:10 +00:00
Aaron Gokaslan
1e2d82b8e4 [BE] Merge isinstance calls together (#94419)
Simplify and speeds up isinstance calls by checking for multiple types at the same time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94419
Approved by: https://github.com/ezyang
2023-02-09 00:47:26 +00:00
Aaron Gokaslan
3ce1ebb6fb Apply some safe comprehension optimizations (#94323)
Optimize unnecessary collection cast calls, unnecessary calls to list, tuple, and dict, and simplify calls to the sorted builtin. This should strictly improve speed and improve readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94323
Approved by: https://github.com/albanD
2023-02-07 23:53:46 +00:00