Build on top of https://github.com/pytorch/pytorch/pull/146821
- Moves enabling pin_memory back inside `_BaseDataLoaderIter`
- This is required for `StatefulDataloader` which leveraged `_BaseDataLoaderIter` directly and not the `Dataloader` class init
- Add a simple test for CPU only env where setting `pin_memory=True` is a no-op.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158323
Approved by: https://github.com/ramanishsingh
Co-authored-by: zeshengzong <zesheng.zong@outlook.com>
Based on https://github.com/pytorch/pytorch/pull/126376, this PR tries to update all PT callers (e.g., `Tensor.is_pinned()`, `Tensor.pin_memory()`) to not pass `device` argument.
As for `storage/untyped_storage.is_pinned()/pin_memory()`, we keep the `device` argument but passing `device` is discouraged. And if not given, the default `device` is still 'cuda' for BC.
Additionally, based on device-agnostic pin_memory, `pin_memory_device` argument of `torch.utils.data.DataLoader` is discouraged now. For BC, explictly passing this argument is still effective. If not given, the default `device` will be the current accelerator.
Fixes#124908
Relates https://github.com/pytorch/pytorch/pull/126376
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131858
Approved by: https://github.com/albanD
Co-authored-by: albanD <desmaison.alban@gmail.com>
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables.
* list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
For the user-defined `Mapping` type, it may contain some metadata (e.g., pytorch/tensordict#679, https://github.com/pytorch/pytorch/pull/120195#issue-2141716712). Simply use `type(mapping)({k: v for k, v in mapping.items()})` do not take this metadata into account. This PR uses `copy.copy(mapping)` to create a clone of the original collection and iteratively updates the elements in the cloned collection. This preserves the metadata in the original collection via `copy.copy(...)` rather than relying on the `__init__` method in the user-defined classes.
Reference:
- pytorch/tensordict#679
- #120195Closes#120195
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120553
Approved by: https://github.com/vmoens
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Alternative to https://github.com/pytorch/pytorch/pull/107034, implements @ezyang 's suggestion from https://github.com/pytorch/pytorch/pull/107034#discussion_r1292857201.
This PR addresses https://fb.workplace.com/groups/pytorch.oss.dev/posts/1699944830430051 and does a bunch of stacked changes:
- Make `Generator` class support GC;this makes all `Generator` instances tracked and accessile through Python's GC.
- Use the GC to retrieve all existing Generator instances in Dataloader's `_worker_loop` and re-seed them: this extends what is already applied to the global/default Generator, which is already re-seeded.
~TODO: a bit of docs and justification, which I'll do if this PR is mergeable.~ -- Done
CC @albanD @ezyang as previously discussed
BC-Breaking Note
-------------------
We now re-seed all `Generator` instances within the `Dataloader` workers' loop to ensure that their RNG is different across workers.
Previously, the RNG of user-defined `Generators` would be the same across workers, which could lead to wrong training procedures. This only affects user-defined `Generators`, not the default `Generator` (which was already re-seeded).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107131
Approved by: https://github.com/ezyang
As per request from Vision team, adding `collate` function with an extra argument of `collate_fn_map` to dispatch custom collate functions for non-collection objects and specific objects.
If the type of batch element is not present in`collate_fn_map`, it will go through all keys in the insertion order to check if the type is a subclass of the key. If so, it will invoke the corresponding collate functions.
And, `default_collate` will utilize the `collate` function with a few by default collate function for `int`, `float`, `str` and `numpy object`.
Benefit:
- Domain teams can register their own `collate` function to handle their specific type of objects
- Easier for users to extend from the `collate` function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85748
Approved by: https://github.com/NivekT, https://github.com/pmeier