Commit Graph

68 Commits

Author SHA1 Message Date
Mikayla Gawarecki
6fa2d41dc7 Add mmap option to torch.load (#102549)
Using [`nanoGPT/model.py`](https://github.com/karpathy/nanoGPT/blob/master/model.py) run

<details><summary><b>Click for script to save gpt2-xlarge (1.5B params)</b></summary>

```
# test_load_save_gpt.py
from model import GPT
import torch
import time

torch.manual_seed(5)
# gpt2-xlarge 1558M parameters
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 48
    n_head: int = 25
    n_embd: int = 1600
    dropout: float = 0.0
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

def f():
    model = GPT(GPTConfig())
    state_dict = model.state_dict()

    start_saving = time.time()
    torch.save(state_dict, "gpt2-xlarge.pth")
    end_saving = time.time()

if __name__ == "__main__":
    f()
```
</details>

<details><summary><b>Click for script to load</b></summary>

```
# test_load_gpt.py

import torch
from model import GPT
from test_load_save_gpt import GPTConfig
import time
import argparse

def f(mmap, meta):
    device = 'meta' if meta else 'cpu'
    assign = True if meta else False
    with torch.device(device):
        model = GPT(GPTConfig())
    start_loading = time.time()
    loaded_state_dict = torch.load("gpt2-xlarge.pth", _mmap=mmap)
    end_loading = time.time()
    print(f"loading time using torch.load with mmap={mmap}: ", end_loading - start_loading)
    model.load_state_dict(loaded_state_dict, assign=assign)
    end_load_state_dict = time.time()
    print("load_state_dict time: ", end_load_state_dict - end_loading)
    model.cuda()
    end_cuda = time.time()
    print("cuda time using torch.load with mmap: ", end_cuda - end_load_state_dict)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='load_gpt_xlarge')
    parser.add_argument('-m', '--mmap', action='store_true')
    parser.add_argument('-d', '--devicemeta', action='store_true')
    args = parser.parse_args()
    mmap = args.mmap
    meta = args.devicemeta
    f(mmap, meta)

```

</details>

`python test_load_gpt.py`

<img width="614" alt="Screenshot 2023-06-06 at 1 35 43 PM" src="https://github.com/pytorch/pytorch/assets/35276741/ee06e5b3-b610-463b-a867-df995d21af29">

`python test_load_gpt.py --mmap`
<img width="622" alt="Screenshot 2023-06-06 at 1 35 30 PM" src="https://github.com/pytorch/pytorch/assets/35276741/00d2fdd0-b1f5-4313-83dc-e540b654b2af">

If we further use the `with torch.device('meta')` context manager and pull the changes from https://github.com/pytorch/pytorch/pull/102212 that allow the model to reuse tensors from the state_dict, we have

`python test_load_gpt.py --mmap --devicemeta`
<img width="727" alt="Screenshot 2023-06-06 at 1 35 51 PM" src="https://github.com/pytorch/pytorch/assets/35276741/b50257d9-092a-49c3-acae-876ee44d009f">

\
\
Running the above in a docker container containing a build of PyTorch with RAM limited to 512mb by

1) running `make -f docker.Makefile` from `pytorch/` directory
2) `docker run -m 512m -it <image> bash`
3) docker cp `gpt2-xlarge.pth` and `test_load_gpt.py` into the image

`python test_load_gpt.py`

Docker will Kill the process due to OOM whereas

`python test_load_gpt.py --mmap --devicemeta`
<img width="635" alt="Screenshot 2023-06-06 at 1 55 48 PM" src="https://github.com/pytorch/pytorch/assets/35276741/f3820d9e-f24c-43e7-885b-3bfdf24ef8ad">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102549
Approved by: https://github.com/albanD
2023-06-09 15:49:58 +00:00
Pearu Peterson
39b04370db Preserve coalesce state in sparse COO tensor serialization (#102647)
Fixes #101186

Also, resolves the "serialization to preserve coalesced-ness" part in https://github.com/pytorch/pytorch/issues/73479

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102647
Approved by: https://github.com/mikaylagawarecki
2023-06-03 01:37:52 +00:00
Rob Guo
111358de19 Support non-ASCII characters in model file paths (#99453)
Fixes #98918

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99453
Approved by: https://github.com/albanD, https://github.com/malfet
2023-04-26 01:15:49 +00:00
Aleksei Nikiforov
87a2af6d4a Fix loading data on different encoding (#94503)
Add endianness marker when saving,
and if it doesn't match host endianness when loading data, do a byteswap.

Older data will load correctly only on systems
with same endianness it was saved on.
New data should load correctly on systems
with any endianness.

Fixes #65300
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94503
Approved by: https://github.com/kurtamohler, https://github.com/ezyang
2023-04-25 21:05:20 +00:00
Nikita Shulga
3da7e83250 Add test for pickle_module (#98373)
I.e. a regression test for https://github.com/pytorch/pytorch/issues/88438

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98373
Approved by: https://github.com/huydhn, https://github.com/kit1980
2023-04-05 13:05:05 +00:00
Nikita Shulga
c01f5118a6 Add float to list of allowed ops (#94910)
By adding `BINFLOAT` op support

Fixes https://github.com/pytorch/pytorch/issues/94670
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94910
Approved by: https://github.com/albanD
2023-02-15 23:13:21 +00:00
Xuehai Pan
046e88a291 [BE] [3/3] Rewrite super() calls in test (#94592)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94592
Approved by: https://github.com/ezyang, https://github.com/seemethere
2023-02-12 22:20:53 +00:00
Huy Do
d51ca38ef0 Run test_serialization serially (for 2xlarge runners) (#94613)
Fixes https://github.com/pytorch/pytorch/issues/92746
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94613
Approved by: https://github.com/clee2000
2023-02-11 00:15:10 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
kshitij12345
745fe35df5 [follow-up] Python Attr Serialization (#88913)
Ref: https://github.com/pytorch/pytorch/pull/81616#issuecomment-1307595402
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88913
Approved by: https://github.com/albanD
2023-01-13 17:38:51 +00:00
Aleksandar Samardžić
8612ec5b90 Implement hybrid sparse to/from dense conversions. (#90177)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90177
Approved by: https://github.com/cpuhrsch, https://github.com/pearu
2023-01-12 03:31:30 +00:00
Kurt Mohler
81b3df4fb0 Fix dtype mismatch for unallocated storage deserialization (#91285)
Fixes #90497

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91285
Approved by: https://github.com/ezyang
2022-12-27 19:31:09 +00:00
Philip Meier
7bb97c4ca4 move TypedStorage handling to assertEqual (#89557)
#85303 added a patch to `torch.testing.assert_close` to handle `torch.storage.TypedStorage`'s. This change is not reflected in the docs and is not intended for the public API. This PR removes the patch ones again and moves the behavior to `TestCase.assertEqual` instead. Meaning, `TypedStorage`'s are again not supported by the public API, but the behavior is the same for all internal use cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89557
Approved by: https://github.com/kurtamohler, https://github.com/mruberry
2022-12-12 23:26:00 +00:00
PyTorch MergeBot
cba96366a2 Revert "remove torch.equal usages (#89527)"
This reverts commit 4095ef8b80.

Reverted https://github.com/pytorch/pytorch/pull/89527 on behalf of https://github.com/clee2000 due to broke periodic multigpu tests 4095ef8b80 https://github.com/pytorch/pytorch/actions/runs/3592806602/jobs/6049368502
2022-12-02 21:36:13 +00:00
PyTorch MergeBot
f5fbb5001f Revert "[follow-up] Python Attr Serialization (#88913)"
This reverts commit 086b251f9a.

Reverted https://github.com/pytorch/pytorch/pull/88913 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-12-02 20:14:11 +00:00
Philip Meier
4095ef8b80 remove torch.equal usages (#89527)
Preparation for the next PR in this stack: #89559.

I replaced

- `self.assertTrue(torch.equal(...))` with `self.assertEqual(..., rtol=0, atol=0, exact_device=True)`,
- the same for `self.assertFalse(...)` with `self.assertNotEqual(...)`, and
- `assert torch.equal(...)` with `torch.testing.assert_close(..., rtol=0, atol=0)` (note that we don't need to set `check_device=True` here since that is the default).

There were a few instances where the result of `torch.equal` is used directly. In that cases I've replaced with `(... == ...).all().item()` while sometimes also dropping the `.item()` depending on the context.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89527
Approved by: https://github.com/mruberry
2022-12-01 11:22:52 +00:00
Kshiteej K
086b251f9a [follow-up] Python Attr Serialization (#88913)
Ref: https://github.com/pytorch/pytorch/pull/81616#issuecomment-1307595402
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88913
Approved by: https://github.com/albanD
2022-11-29 16:46:20 +00:00
Pearu Peterson
50e2e4faf3 Sparse CSC/BSR/BSC serialization and pickle support (#89553)
Fixes https://github.com/pytorch/pytorch/issues/89497

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89553
Approved by: https://github.com/cpuhrsch
2022-11-23 20:56:48 +00:00
kshitij12345
f74946324e [fix] allow saving python attr on Tensor and Parameter via torch.save (#81616)
Fixes: https://github.com/pytorch/pytorch/issues/72129

TODO:
* [x] Fix for Parameter

Benchmark
(Measurable diff for small tensors)
```
[-------------- Save and Load --------------]
                    |  After PR  |  Before PR
1 threads: ----------------------------------
      ()            |    111.7   |     106.9
      (4, 4)        |    114.4   |     109.2
      (128, 128)    |    135.2   |     128.3
      (1024, 1024)  |   1431.9   |    1431.3

Times are in microseconds (us).
```

<details>

<summary> Benchmark Script </summary>

```python
import torch
from torch.testing._internal.common_utils import BytesIOContext
from torch.utils import benchmark
import pickle

shapes = ((), (4, 4), (128, 128), (1024, 1024))

sizes = [1, 64, 1024, 10000]
results = []

def save_load_fn(t):
    with BytesIOContext() as f:
        torch.save(t, f)
        f.seek(0)
        torch.load(f)

for shape in shapes:
    t = torch.randn(shape)
    label = 'Save and Load'
    sub_label = f'{shape}'
    results.append(benchmark.Timer(
        stmt='save_load_fn(t)',
        globals={'t': t, 'save_load_fn':save_load_fn},
        label=label,
        sub_label=sub_label,
        description='Before PR',
    ).blocked_autorange(min_run_time=2))

compare = benchmark.Compare(results)
compare.print()

with open('before_pr.pkl', 'wb') as f:
    pickle.dump(results, f)

# with open('after_pr.pkl', 'rb') as f:
#     after_pr = pickle.load(f)

# with open('before_pr.pkl', 'rb') as f:
#     before_pr = pickle.load(f)

# compare = benchmark.Compare(after_pr + before_pr)
# compare.print()
```

</details>

NOTE : **BC-Breaking** : After this PR, all tensors (also regular tensors) will be serialised using `_rebuild_from_type_v2`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81616
Approved by: https://github.com/albanD, https://github.com/kurtamohler
2022-11-11 21:11:12 +00:00
Kurt Mohler
89a326ff7e Explicitly check filelike arg of torch.save (#88867)
Fixes #88793

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88867
Approved by: https://github.com/ezyang
2022-11-11 16:57:08 +00:00
kshitij12345
d15a6b0c97 Error on ZeroTensor serialization (#88803)
Follow-up : https://github.com/pytorch/pytorch/pull/88182#issuecomment-1308628415

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88803
Approved by: https://github.com/anjali411
2022-11-11 08:51:29 +00:00
kshitij12345
eb9b156019 [fix] MathBits: serialization (#88182)
Fixes #81690

TODO:

* [x] C++ Unpickler Fix (locally tested pickled in Python and unpickled in C++)
* [x] C++ Pickler Fix (locally tested pickled in C++ and unpickled in Python)
* [x] Do quant_tensor, sparse_tensor, etc require similar changes? (Sparse and Quant don't need this)
* [x] Add Comments
* [x] How to make sure C++ and Python are in sync? (Functions in `pickler.h` help in getting and setting Tensor Metadata (math-bits for now) on a tensor. They are the only place which should handle this.)

Notes:
Quant Tensor don't support complex dtypes and for float they segfault with `_neg_view` : https://github.com/pytorch/pytorch/issues/88484

Sparse Tensor:
```python
>>> a = torch.tensor([[0, 2.], [3j, 0]]).to_sparse()
>>> a.conj().is_conj()
False
>>> a._neg_view()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: Cannot access storage of SparseTensorImpl
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88182
Approved by: https://github.com/ezyang, https://github.com/anjali411
2022-11-09 17:15:12 +00:00
PyTorch MergeBot
78a0ca29d9 Revert "[fix] allow saving python attr on Tensor and Parameter via torch.save (#81616)"
This reverts commit 54b6188cc6.

Reverted https://github.com/pytorch/pytorch/pull/81616 on behalf of https://github.com/mehtanirav due to Internal publishing is broken
2022-11-07 18:51:16 +00:00
Kshiteej K
54b6188cc6 [fix] allow saving python attr on Tensor and Parameter via torch.save (#81616)
Fixes: https://github.com/pytorch/pytorch/issues/72129

TODO:
* [x] Fix for Parameter

Benchmark
(Measurable diff for small tensors)
```
[-------------- Save and Load --------------]
                    |  After PR  |  Before PR
1 threads: ----------------------------------
      ()            |    111.7   |     106.9
      (4, 4)        |    114.4   |     109.2
      (128, 128)    |    135.2   |     128.3
      (1024, 1024)  |   1431.9   |    1431.3

Times are in microseconds (us).
```

<details>

<summary> Benchmark Script </summary>

```python
import torch
from torch.testing._internal.common_utils import BytesIOContext
from torch.utils import benchmark
import pickle

shapes = ((), (4, 4), (128, 128), (1024, 1024))

sizes = [1, 64, 1024, 10000]
results = []

def save_load_fn(t):
    with BytesIOContext() as f:
        torch.save(t, f)
        f.seek(0)
        torch.load(f)

for shape in shapes:
    t = torch.randn(shape)
    label = 'Save and Load'
    sub_label = f'{shape}'
    results.append(benchmark.Timer(
        stmt='save_load_fn(t)',
        globals={'t': t, 'save_load_fn':save_load_fn},
        label=label,
        sub_label=sub_label,
        description='Before PR',
    ).blocked_autorange(min_run_time=2))

compare = benchmark.Compare(results)
compare.print()

with open('before_pr.pkl', 'wb') as f:
    pickle.dump(results, f)

# with open('after_pr.pkl', 'rb') as f:
#     after_pr = pickle.load(f)

# with open('before_pr.pkl', 'rb') as f:
#     before_pr = pickle.load(f)

# compare = benchmark.Compare(after_pr + before_pr)
# compare.print()
```

</details>

NOTE : **BC-Breaking** : After this PR, all tensors (also regular tensors) will be serialised using `_rebuild_from_type_v2`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81616
Approved by: https://github.com/albanD, https://github.com/kurtamohler
2022-11-03 09:57:47 +00:00
Nikita Shulga
caaf37a111 Fix PyTorchStreamWriter exception handling (#88128)
Avoid double exception in destructor if attempting to serialize to
python object that does not have `write` method

Use `Finalizer` class in `PyTorchStreamWriter::writeEndOfFile()` to a
always set `finailized_` property even if excretion occurs. (as there
isn't much one can do at this point)

Add expicit check for the attribue to `_open_zipfile_writer_buffer` and
add unitests

Modernize code a bit by using Python-3 `super()` method

Fixes https://github.com/pytorch/pytorch/issues/87997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88128
Approved by: https://github.com/albanD
2022-10-31 23:38:03 +00:00
Nikita Shulga
961ebca225 Add weights_only option to torch.load (#86812)
This addresses the security issue in default Python's `unpickler` that allows arbitrary code execution while unpickling.
Restrict classes allowed to be unpicked to in `None`, `int`, `bool`, `str`, `float`, `list`, `tuple`, `dict`/`OrderedDict` as well as `torch.Size`, `torch.nn.Param` as well as  `torch.Tensor` and `torch.Storage` variants.

Defaults `weights_only` is set to `False`,  but allows global override to safe only load via `TORCH_FORCE_WEIGHTS_ONLY_LOAD` environment variable.

To some extent, addresses https://github.com/pytorch/pytorch/issues/52596
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86812
Approved by: https://github.com/ezyang
2022-10-21 01:09:50 +00:00
Nikita Shulga
4a533f1215 Tweak several test serialization to store models state_dict (#87143)
Namely, change:
- `test_meta_serialization`
- `test_serialization_2gb_file`
- `test_pathlike_serialization`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87143
Approved by: https://github.com/ezyang
2022-10-19 20:51:32 +00:00
Kurt Mohler
14d0296e5c Rename _Typed/_UntypedStorage to Typed/UntypedStorage and update docs (#82438)
### Description

Since the major changes for `_TypedStorage` and `_UntypedStorage` are now complete, they can be renamed to be public.

`TypedStorage._untyped()` is renamed to `TypedStorage.untyped()`.

Documentation for storages is improved as well.

### Issue
Fixes #82436

### Testing
N/A

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82438
Approved by: https://github.com/ezyang
2022-07-30 19:37:08 +00:00
PyTorch MergeBot
da87fa684c Revert "[fix] allow saving python attr on Tensor and Parameter via torch.save (#81616)"
This reverts commit f3f8d96ea6.

Reverted https://github.com/pytorch/pytorch/pull/81616 on behalf of https://github.com/jeanschmidt due to breaking internal builds
2022-07-21 10:46:24 +00:00
kshitij12345
f3f8d96ea6 [fix] allow saving python attr on Tensor and Parameter via torch.save (#81616)
Fixes: https://github.com/pytorch/pytorch/issues/72129

TODO:
* [x] Fix for Parameter

Benchmark
(Measurable diff for small tensors)
```
[-------------- Save and Load --------------]
                    |  After PR  |  Before PR
1 threads: ----------------------------------
      ()            |    111.7   |     106.9
      (4, 4)        |    114.4   |     109.2
      (128, 128)    |    135.2   |     128.3
      (1024, 1024)  |   1431.9   |    1431.3

Times are in microseconds (us).
```

<details>

<summary> Benchmark Script </summary>

```python
import torch
from torch.testing._internal.common_utils import BytesIOContext
from torch.utils import benchmark
import pickle

shapes = ((), (4, 4), (128, 128), (1024, 1024))

sizes = [1, 64, 1024, 10000]
results = []

def save_load_fn(t):
    with BytesIOContext() as f:
        torch.save(t, f)
        f.seek(0)
        torch.load(f)

for shape in shapes:
    t = torch.randn(shape)
    label = 'Save and Load'
    sub_label = f'{shape}'
    results.append(benchmark.Timer(
        stmt='save_load_fn(t)',
        globals={'t': t, 'save_load_fn':save_load_fn},
        label=label,
        sub_label=sub_label,
        description='Before PR',
    ).blocked_autorange(min_run_time=2))

compare = benchmark.Compare(results)
compare.print()

with open('before_pr.pkl', 'wb') as f:
    pickle.dump(results, f)

# with open('after_pr.pkl', 'rb') as f:
#     after_pr = pickle.load(f)

# with open('before_pr.pkl', 'rb') as f:
#     before_pr = pickle.load(f)

# compare = benchmark.Compare(after_pr + before_pr)
# compare.print()
```

</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81616
Approved by: https://github.com/albanD
2022-07-20 18:45:33 +00:00
albanD
1afb804f26 Improve wrapper subclass detection for serialization (#81105)
Fixes https://github.com/pytorch/pytorch/issues/80983

Also fix a small bug uncovered by the new test where creating memory_view for 0-sized inputs is not valid and is now skipped

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81105
Approved by: https://github.com/ezyang
2022-07-11 14:02:37 +00:00
Alban Desmaison
e4d5801e36 Make sure requires_grad is propagated for all backend
The if statement is not strictly necessary but that avoid having to call this function if we don't need it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76256
Approved by: https://github.com/ezyang, https://github.com/soulitzer
2022-04-25 19:31:24 +00:00
Nikita Shulga
bfac65dfe5
[testing] Update dispatch macros (#74977)
This PR is reland of #74289 
Co-authored-by: Khushi Agrawal <khushiagrawal411@gmail.com>
2022-03-30 14:13:21 -07:00
PyTorch MergeBot
2e4152b118 Revert "[testing] Update dispatch macros"
This reverts commit eed19a0f38.

Reverted https://github.com/pytorch/pytorch/pull/74289 on behalf of https://github.com/malfet
2022-03-30 19:52:37 +00:00
Khushi Agrawal
eed19a0f38 [testing] Update dispatch macros
Hi,
This PR is the follow-up PR of #71561. (the previous PR had a couple of merge conflicts and was reverted, this PR resolves that).
Please take a look. Thanks!

cc: @pmeier @mruberry @kshitij12345
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74289
Approved by: https://github.com/pmeier, https://github.com/mruberry
2022-03-30 16:10:16 +00:00
Nikita Shulga
ef066f0832 Revert D34856571: [pytorch][PR] Replace get_all_ type macros with the ATen dispatch macros.
Test Plan: revert-hammer

Differential Revision:
D34856571 (3ded7b1da3)

Original commit changeset: 0dca038bcad5

Original Phabricator Diff: D34856571 (3ded7b1da3)

fbshipit-source-id: 594553fa0b710d78beba59d5d2b646f1f1270386
(cherry picked from commit 8090eb9b12dcf452a9e7dc01792a66fb91b563b6)
2022-03-15 22:07:11 +00:00
Khushi Agrawal
3ded7b1da3 Replace get_all_ type macros with the ATen dispatch macros. (#71561)
Summary:
Hi, Team!
The PR is motivated from https://github.com/pytorch/pytorch/pull/71153#discussion_r782446738. It aims to replace `get_all` type macros with the ATen dispatch macros.

The files it iterates over are: (Thanks, Lezcano, for the idea!!)

<details>
<summary>

`test/test_autograd.py`</summary>

<p>

```python
43:from torch.testing._internal.common_dtype import get_all_dtypes
8506:        floating_dt = [dt for dt in get_all_dtypes() if dt.is_floating_point]
```

</p>
</details>

<details>
<summary>

`test/test_binary_ufuncs.py`</summary>

<p>

```python
26:    all_types_and_complex_and, integral_types_and, get_all_dtypes, get_all_int_dtypes, get_all_math_dtypes,
27:    get_all_complex_dtypes, get_all_fp_dtypes,
935:    dtypes(*get_all_dtypes(include_bool=False, include_complex=False))
1035:    dtypes(*get_all_dtypes(
1488:    dtypes(*(get_all_dtypes(include_bool=False, include_bfloat16=False)))
1879:    dtypes(*product(get_all_dtypes(include_complex=False), get_all_dtypes(include_complex=False)))
1887:    dtypes(*(get_all_int_dtypes() + [torch.bool]))
1913:    dtypes(*(get_all_fp_dtypes()))
1941:    dtypes(*(get_all_fp_dtypes()))
1977:    dtypes(*product(get_all_complex_dtypes(), get_all_dtypes()))
2019:    dtypes(*product(get_all_fp_dtypes(), get_all_fp_dtypes()))
2048:    dtypes(*get_all_dtypes())
2110:    dtypes(*product(get_all_dtypes(include_complex=False),
2111:                     get_all_dtypes(include_complex=False)))
2128:            types = [torch.bool, torch.bfloat16] + get_all_int_dtypes()
2173:        if dtypes[1] in get_all_fp_dtypes():
2178:    dtypes(*product(get_all_fp_dtypes(),
2179:                     get_all_fp_dtypes()))
2260:    dtypesIfCUDA(*set(get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128})
2261:    dtypes(*set(get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128})
2273:    dtypesIfCUDA(*set(get_all_math_dtypes('cuda')) - {torch.complex64, torch.complex128})
2274:    dtypes(*set(get_all_math_dtypes('cpu')) - {torch.complex64, torch.complex128})
2307:    dtypes(*get_all_math_dtypes('cpu'))
2319:    dtypes(*get_all_fp_dtypes(include_bfloat16=False))
2331:    dtypes(*get_all_int_dtypes())
2356:    dtypes(*get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False))
2393:        if dtype in get_all_int_dtypes():
2614:    dtypes(*get_all_dtypes())
2624:    dtypes(*tuple(itertools.combinations_with_replacement(get_all_dtypes(), 2)))
2806:    dtypes(*list(product(get_all_dtypes(include_complex=False),
2807:                          get_all_dtypes(include_complex=False))))
2866:    dtypes(*list(product(get_all_complex_dtypes(),
2867:                          get_all_complex_dtypes())))
2902:    dtypes(*product(get_all_dtypes(), get_all_dtypes()))
2906:    dtypes(*product(get_all_dtypes(), get_all_dtypes()))
2910:    dtypes(*product(get_all_dtypes(), get_all_dtypes()))
3019:        dtypes = [torch.float, torch.double] + get_all_complex_dtypes()
3221:    dtypes(*get_all_dtypes(include_complex=False))
3407:    dtypes(*list(product(get_all_dtypes(include_bool=False),
3408:                          get_all_dtypes(include_bool=False))))
3504:    dtypes(*product(get_all_dtypes(include_complex=False, include_bfloat16=False),
3505:                     get_all_dtypes(include_complex=False, include_bfloat16=False)))
3516:            if x.dtype in get_all_int_dtypes() + [torch.bool]:
3643:    dtypes(*product(get_all_dtypes(include_complex=False,
3645:                     get_all_dtypes(include_complex=False,
```

</p>
</details>

<details>
<summary>

`test/test_complex.py`</summary>

<p>

```python
6:from torch.testing._internal.common_dtype import get_all_complex_dtypes
11:    dtypes(*get_all_complex_dtypes())
```

</p>
</details>

<details>
<summary>

`test/test_foreach.py`</summary>

<p>

```python
18:    get_all_dtypes, get_all_int_dtypes, get_all_complex_dtypes, get_all_fp_dtypes,
142:            if dtype in get_all_int_dtypes():
179:            disable_fastpath = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool]
201:            disable_fastpath = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool]
205:                disable_fastpath |= dtype in get_all_int_dtypes() + [torch.bool]
211:                disable_fastpath |= dtype not in get_all_complex_dtypes()
241:                bool_int_div = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool]
246:                    disable_fastpath |= dtype in get_all_int_dtypes() + [torch.bool]
248:                    disable_fastpath |= dtype not in get_all_complex_dtypes()
250:                    disable_fastpath |= True and dtype not in get_all_complex_dtypes()
307:        disable_fastpath = dtype in get_all_int_dtypes() + [torch.bool]
365:        if opinfo.name == "_foreach_abs" and dtype in get_all_complex_dtypes():
376:    ops(foreach_unary_op_db, dtypes=get_all_dtypes())
393:         dtypes=get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
401:    ops(foreach_minmax_op_db, dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=True))
426:            if ord in (1, 2) and dtype in torch.testing.get_all_fp_dtypes():
439:    dtypes(*get_all_dtypes())
449:    ops(foreach_binary_op_db, dtypes=get_all_dtypes())
481:    ops(foreach_binary_op_db, dtypes=get_all_dtypes())
536:            if dtype in get_all_int_dtypes() + [torch.bool] and foreach_op == torch._foreach_div:
545:    ops(foreach_binary_op_db, dtypes=get_all_dtypes())
637:    ops(foreach_pointwise_op_db, allowed_dtypes=get_all_fp_dtypes(include_half=False, include_bfloat16=False))
```

</p>
</details>

<details>
<summary>

`test/test_linalg.py`</summary>

<p>

```python
29:    all_types, floating_types, floating_and_complex_types, get_all_dtypes, get_all_int_dtypes, get_all_complex_dtypes,
30:    get_all_fp_dtypes,
111:    dtypes(*(get_all_dtypes()))
794:        float_and_complex_dtypes = get_all_fp_dtypes() + get_all_complex_dtypes()
807:    dtypes(*(get_all_int_dtypes()))
828:    dtypes(*(get_all_fp_dtypes() + get_all_complex_dtypes()))
841:        if dtype in get_all_complex_dtypes():
844:    dtypes(*itertools.product(get_all_dtypes(),
845:                               get_all_dtypes()))
855:        for dtypes0, dtypes1, dtypes2 in product(get_all_dtypes(), repeat=3):
5607:                  *get_all_fp_dtypes(include_half=not CUDA9, include_bfloat16=(CUDA11OrLater and SM53OrLater)))
5608:    dtypes(*(set(get_all_dtypes()) - {torch.half, torch.bool}))
5644:    dtypes(*(get_all_complex_dtypes() + get_all_fp_dtypes()))
6255:    dtypesIfCUDA(*get_all_complex_dtypes(),
6256:                  *get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)),
6292:    dtypesIfCUDA(*get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater))))
6323:    dtypesIfCUDA(*get_all_complex_dtypes(),
6324:                  *get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater))))
6325:    dtypes(*get_all_complex_dtypes(), *get_all_fp_dtypes())
6358:    dtypesIfCUDA(*([torch.float, torch.double] + get_all_complex_dtypes()))
6556:    dtypes(*get_all_fp_dtypes(), *get_all_complex_dtypes())
6668:    dtypes(*get_all_fp_dtypes(), *get_all_complex_dtypes())
6741:    dtypes(*get_all_fp_dtypes(), *get_all_complex_dtypes())
```

</p>
</details>

<details>
<summary>

`test/test_nn.py`</summary>

<p>

```python
37:from torch.testing._internal.common_dtype import integral_types, get_all_fp_dtypes, get_all_math_dtypes
50:    onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, skipMeta, get_all_device_types, \
8862:                for device in get_all_device_types():
9629:            for dt1 in get_all_math_dtypes(device):
9630:                for dt2 in get_all_math_dtypes(device):
9631:                    for dt3 in get_all_math_dtypes(device):
9648:            for input_dtype in get_all_math_dtypes(device):
9664:            for input_dtype in get_all_math_dtypes(device):
13015:    dtypes(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
13034:    dtypes(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
13159:    dtypes(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
17400:    dtypesIfCUDA(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
17768:    dtypesIfCUDA(*get_all_fp_dtypes())
17773:    dtypesIfCUDA(*get_all_fp_dtypes())
17778:    dtypesIfCUDA(*get_all_fp_dtypes())
17783:    dtypesIfCUDA(*get_all_fp_dtypes())
17788:    dtypesIfCUDA(*get_all_fp_dtypes())
17793:    dtypesIfCUDA(*get_all_fp_dtypes())
17798:    dtypesIfCUDA(*get_all_fp_dtypes())
17963:    dtypesIfCUDA(*get_all_fp_dtypes())
17977:    dtypesIfCUDA(*get_all_fp_dtypes())
18684:    def test_cross_entropy_loss_prob_target_all_reductions(self, device):
```

</p>
</details>

<details>
<summary>

`test/test_numpy_interop.py`</summary>

<p>

```python
12:from torch.testing._internal.common_dtype import get_all_dtypes
399:    dtypes(*get_all_dtypes())
```

</p>
</details>

<details>
<summary>

`test/test_ops.py`</summary>

<p>

```python
12:from torch.testing._internal.common_dtype import floating_and_complex_types_and, get_all_dtypes
86:        for dtype in get_all_dtypes():
```

</p>
</details>

<details>
<summary>

`test/test_reductions.py`</summary>

<p>

```python
16:    get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_complex_dtypes, get_all_fp_dtypes,
360:         allowed_dtypes=get_all_dtypes(include_bfloat16=False))
366:         allowed_dtypes=get_all_dtypes(include_bfloat16=False))
394:         allowed_dtypes=get_all_dtypes(include_bfloat16=False))
750:        for dtype in [dtype for dtype in get_all_math_dtypes('cpu') if dtype != torch.float16]:
1404:    dtypes(*get_all_dtypes(include_bool=False, include_complex=False))
1457:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) +
1458:              get_all_complex_dtypes()))
1465:            return dtype in get_all_int_dtypes()
1494:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False)))
1501:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False)))
1507:    dtypes(*(get_all_complex_dtypes()))
1514:        dtypes = list(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False))
1523:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False)))
1531:        if dtype in get_all_fp_dtypes():
1608:    dtypes(*(get_all_dtypes(include_half=True, include_bfloat16=False,
1837:    dtypes(*get_all_dtypes(include_bool=False, include_complex=False))
1855:    dtypes(*(set(get_all_dtypes(include_bool=False, include_complex=False)) - {torch.uint8}))
3219:        for dtype in get_all_dtypes(include_half=True, include_bfloat16=False,
```

</p>
</details>

<details>
<summary>

`test/test_serialization.py`</summary>

<p>

```python
26:from torch.testing._internal.common_dtype import get_all_dtypes
586:        for device, dtype in product(devices, get_all_dtypes()):
589:            for other_dtype in get_all_dtypes():
```

</p>
</details>

<details>
<summary>

`test/test_shape_ops.py`</summary>

<p>

```python
18:from torch.testing._internal.common_dtype import get_all_dtypes
230:    dtypes(*get_all_dtypes(include_complex=False, include_bool=False, include_half=False,
232:    dtypesIfCUDA(*get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False))
344:    dtypes(*get_all_dtypes())
443:    dtypes(*get_all_dtypes())
461:    dtypes(*get_all_dtypes())
570:    dtypes(*get_all_dtypes(include_complex=False))
```

</p>
</details>

<details>
<summary>

`test/test_sort_and_select.py`</summary>

<p>

```python
12:    all_types, all_types_and, floating_types_and, get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes,
136:    dtypes(*set(get_all_dtypes()) - {torch.bool, torch.complex64, torch.complex128})
231:    dtypes(*set(get_all_dtypes()) - {torch.bool, torch.complex64, torch.complex128})
296:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
647:    dtypesIfCUDA(*get_all_fp_dtypes())
678:    dtypesIfCUDA(*(get_all_dtypes(include_complex=False,
682:    dtypes(*(get_all_dtypes(include_complex=False, include_bool=False, include_half=False, include_bfloat16=False)))
739:    dtypesIfCPU(*set(get_all_dtypes()) - {torch.complex64, torch.complex128})
740:    dtypes(*set(get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128})
799:    dtypesIfCPU(*set(get_all_dtypes()) - {torch.complex64, torch.complex128})
800:    dtypes(*set(get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128})
```

</p>
</details>

<details>
<summary>

`test/test_sparse.py`</summary>

<p>

```python
20:from torch.testing import get_all_complex_dtypes, get_all_fp_dtypes
29:    floating_and_complex_types, floating_and_complex_types_and, get_all_dtypes, get_all_int_dtypes,
1963:            return dtype in get_all_int_dtypes()
1994:    dtypes(*get_all_dtypes(include_bool=False, include_half=False,
2103:            return dtype in get_all_int_dtypes()
2138:    dtypes(*get_all_dtypes(include_bool=False, include_half=False,
2626:        all_sparse_dtypes = get_all_dtypes(include_complex=True)
2633:        all_sparse_dtypes = get_all_dtypes(include_complex=True)
3230:    dtypes(*get_all_complex_dtypes(),
3231:            *get_all_fp_dtypes(include_half=False, include_bfloat16=False))
3234:                  *get_all_fp_dtypes(
```

</p>
</details>

<details>
<summary>

`test/test_sparse_csr.py`</summary>

<p>

```python
7:from torch.testing import get_all_complex_dtypes, get_all_fp_dtypes, floating_and_complex_types, make_tensor
17:from torch.testing._internal.common_dtype import floating_types, get_all_dtypes
120:    dtypes(*get_all_dtypes())
133:    dtypes(*get_all_dtypes())
150:    dtypes(*get_all_dtypes())
180:    dtypes(*get_all_dtypes())
201:    dtypes(*get_all_dtypes())
210:    dtypes(*get_all_dtypes())
225:    dtypes(*get_all_dtypes())
244:    dtypes(*get_all_dtypes())
263:    dtypes(*get_all_dtypes())
285:    dtypes(*get_all_dtypes())
411:    dtypes(*get_all_dtypes())
482:    dtypes(*get_all_dtypes())
502:    dtypes(*get_all_dtypes())
562:    dtypes(*get_all_dtypes())
588:    dtypesIfCUDA(*get_all_complex_dtypes(),
589:                  *get_all_fp_dtypes(include_half=SM53OrLater, include_bfloat16=SM80OrLater))
745:    dtypesIfCUDA(*get_all_complex_dtypes(),
746:                  *get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC,
765:    dtypesIfCUDA(*get_all_complex_dtypes(),
766:                  *get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC,
801:                  *torch.testing.get_all_fp_dtypes(include_bfloat16=SM80OrLater,
841:                  *torch.testing.get_all_fp_dtypes(include_bfloat16=SM80OrLater,
1182:    dtypes(*get_all_dtypes())
1276:    dtypes(*get_all_dtypes(include_bool=False, include_half=False, include_bfloat16=False))
1286:    dtypes(*get_all_dtypes())
```

</p>
</details>

<details>
<summary>

`test/test_tensor_creation_ops.py`</summary>

<p>

```python
21:    onlyCUDA, skipCPUIf, dtypesIfCUDA, skipMeta, get_all_device_types)
23:    get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
150:        for dt in get_all_dtypes():
160:        for dt in get_all_dtypes():
314:        dtypes = [dtype for dtype in get_all_dtypes() if dtype != torch.bfloat16]
1012:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) +
1013:              get_all_complex_dtypes()))
1032:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) +
1033:              get_all_complex_dtypes()))
1050:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) +
1051:              get_all_complex_dtypes()))
1745:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
1779:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
1868:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
1926:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
1954:            do_test_empty_full(self, get_all_math_dtypes('cpu'), torch.strided, torch_device)
1956:            do_test_empty_full(self, get_all_math_dtypes('cpu'), torch.strided, None)
1957:            do_test_empty_full(self, get_all_math_dtypes('cpu'), torch.strided, torch_device)
2538:        for device in get_all_device_types():
2645:        for dtype in get_all_dtypes():
2678:    dtypes(*(get_all_fp_dtypes(include_half=False, include_bfloat16=False) +
2679:              get_all_complex_dtypes()))
2716:    dtypes(*get_all_fp_dtypes(include_half=False, include_bfloat16=False))
2827:            for dt in get_all_dtypes():
2913:    dtypes(*get_all_dtypes(include_bool=False, include_half=False))
2914:    dtypesIfCUDA(*get_all_dtypes(include_bool=False, include_half=True))
3028:    dtypes(*(get_all_fp_dtypes() + get_all_complex_dtypes()))
3033:    dtypes(*(get_all_fp_dtypes() + get_all_complex_dtypes()))
3074:    dtypes(*get_all_dtypes(include_bool=False, include_half=False, include_complex=False))
3075:    dtypesIfCUDA(*((get_all_int_dtypes() + [torch.float32, torch.float16, torch.bfloat16])
3077:                    else get_all_dtypes(include_bool=False, include_half=True, include_complex=False)))
3873:    dtypes(*get_all_dtypes())
3884:    dtypes(*get_all_dtypes(include_bool=False))
3916:            for other in get_all_dtypes():
3922:    dtypes(*get_all_dtypes())
3932:    dtypes(*get_all_dtypes(include_bool=False))
3955:    dtypes(*get_all_dtypes(include_bool=False))
3961:    dtypes(*get_all_dtypes(include_bool=False))
3965:    dtypes(*get_all_dtypes())
```

</p>
</details>

<details>
<summary>

`test/test_testing.py`</summary>

<p>

```python
25:from torch.testing._internal.common_dtype import get_all_dtypes
31:    dtypes(*(get_all_dtypes(include_half=True, include_bfloat16=False,
```

</p>
</details>

<details>
<summary>

`test/test_torch.py`</summary>

<p>

```python
51:    expectedAlertNondeterministic, get_all_device_types, skipXLA)
57:    get_all_fp_dtypes, get_all_int_dtypes, get_all_math_dtypes, get_all_dtypes, get_all_complex_dtypes
296:            for d in get_all_device_types():
323:            for device in get_all_device_types():
324:                for dt1 in get_all_dtypes():
325:                    for dt2 in get_all_dtypes():
343:            all_dtypes = get_all_dtypes()
350:            all_dtypes = get_all_dtypes()
781:            for dtype in get_all_dtypes():
986:            for device in get_all_device_types():
1017:            for device in get_all_device_types():
1018:                for dtype in get_all_math_dtypes(device):
2792:            for device in get_all_device_types():
3186:    dtypes(*get_all_dtypes())
3195:        for error_dtype in get_all_dtypes():
3203:    dtypes(*get_all_dtypes())
3212:        for error_dtype in get_all_dtypes():
4539:    dtypes(*get_all_fp_dtypes())
4545:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
4577:    dtypes(*get_all_fp_dtypes(include_half=False, include_bfloat16=False))
4578:    dtypesIfCPU(*(get_all_fp_dtypes(include_half=False, include_bfloat16=True)))
4579:    dtypesIfCUDA(*(get_all_fp_dtypes(include_bfloat16=False)))
4599:    dtypes(*(get_all_fp_dtypes(include_half=False, include_bfloat16=False)))
4600:    dtypesIfCPU(*(get_all_dtypes(include_half=False, include_bfloat16=False, include_complex=False)))
4601:    dtypesIfCUDA(*(get_all_dtypes(include_bfloat16=False, include_complex=False)))
4613:        for p_dtype in get_all_fp_dtypes(include_half=device.startswith('cuda'), include_bfloat16=False):
4628:    dtypes(*(get_all_fp_dtypes(include_half=False, include_bfloat16=False)))
4629:    dtypesIfCUDA(*(get_all_fp_dtypes(include_bfloat16=False)))
4640:    dtypes(*get_all_fp_dtypes())
4723:    dtypes(*get_all_fp_dtypes())
4735:    dtypes(*get_all_fp_dtypes(include_bfloat16=False))
4736:    dtypesIfCUDA(*get_all_fp_dtypes())
4747:    dtypes(*get_all_fp_dtypes())
4761:    dtypes(*get_all_fp_dtypes())
4771:    dtypes(*get_all_fp_dtypes())
4792:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
5302:    dtypes(*get_all_dtypes(include_bfloat16=False))
5322:    dtypes(*get_all_dtypes(include_half=False, include_bfloat16=False))
5323:    dtypesIfCPU(*get_all_dtypes(include_bfloat16=False))
5324:    dtypesIfCUDA(*get_all_dtypes(include_bfloat16=False))
5591:        for dt in get_all_dtypes():
5611:        for dt in get_all_dtypes():
5678:        for dt in get_all_dtypes():
5696:    dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
5697:    dtypes(*set(get_all_math_dtypes('cpu')))
5746:    dtypes(*get_all_dtypes())
5780:    dtypes(*get_all_dtypes())
5885:    dtypes(*get_all_dtypes())
5902:    dtypes(*get_all_dtypes())
5945:    dtypes(*get_all_dtypes())
5979:    dtypes(*get_all_dtypes(include_bool=False))
6049:    dtypes(*get_all_dtypes(include_bool=False))
6092:    dtypes(*(get_all_fp_dtypes(include_bfloat16=False, include_half=False) +
6093:              get_all_complex_dtypes()))
6094:    dtypesIfCPU(*get_all_dtypes())
6095:    dtypesIfCUDA(*get_all_dtypes())
6122:    dtypes(*(get_all_fp_dtypes(include_bfloat16=False, include_half=False) +
6123:              get_all_complex_dtypes()))
6124:    dtypesIfCPU(*get_all_dtypes())
6125:    dtypesIfCUDA(*get_all_dtypes())
6163:    dtypes(*(get_all_fp_dtypes(include_bfloat16=False, include_half=False) +
6164:              get_all_complex_dtypes()))
6165:    dtypesIfCPU(*get_all_dtypes())
6166:    dtypesIfCUDA(*get_all_dtypes())
6190:    dtypes(*(get_all_complex_dtypes() +
6191:              get_all_int_dtypes()))
6238:    dtypes(*get_all_dtypes())
6323:    dtypes(*get_all_dtypes())
6389:    dtypes(*product(get_all_dtypes(), (torch.uint8, torch.bool)))
6699:    dtypesIfCUDA(*set(get_all_math_dtypes('cuda')))
6700:    dtypes(*set(get_all_math_dtypes('cpu')))
7452:    dtypes(*get_all_dtypes(include_bool=False))
7461:    dtypes(*get_all_dtypes(include_bool=False))
7477:    dtypes(*get_all_dtypes(include_bool=False))
7496:    dtypes(*get_all_dtypes(include_bool=False))
7538:    dtypes(*get_all_dtypes(include_bool=False))
8162:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes() +
8163:              get_all_complex_dtypes()))
8175:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes() +
8176:              get_all_complex_dtypes()))
```

</p>
</details>

<details>
<summary>

`test/test_type_promotion.py`</summary>

<p>

```python
14:    get_all_dtypes, get_all_math_dtypes, get_all_int_dtypes, get_all_fp_dtypes
187:        for dtype in get_all_dtypes():
262:        dtypes1 = get_all_math_dtypes('cuda')
263:        dtypes2 = get_all_math_dtypes(device)
339:    dtypes(*itertools.product(get_all_dtypes(), get_all_dtypes()))
468:            for dt1 in get_all_math_dtypes(device):
469:                for dt2 in get_all_math_dtypes(device):
519:            for dt1 in get_all_math_dtypes(device):
520:                for dt2 in get_all_math_dtypes(device):
528:        for dt in get_all_math_dtypes(device):
561:        for dtype in get_all_dtypes():
766:                                          dtypes=get_all_math_dtypes(device))
771:                                          dtypes=get_all_math_dtypes(device))
782:                                          dtypes=get_all_math_dtypes(device))
879:        dtypes = get_all_dtypes(include_bfloat16=False)
898:        dtypes = get_all_dtypes(include_bfloat16=False, include_bool=False)
965:    dtypesIfCUDA(*itertools.product(get_all_dtypes(include_bfloat16=False, include_complex=False),
966:                                     get_all_dtypes(include_bfloat16=False, include_complex=False)))
967:    dtypes(*itertools.product(get_all_dtypes(include_half=False, include_bfloat16=False,
969:                               get_all_dtypes(include_half=False, include_bfloat16=False,
976:            return dtype in get_all_int_dtypes() + [torch.bool]
979:            return dtype in get_all_fp_dtypes(include_half=True, include_bfloat16=False)
```

</p>
</details>

<details>
<summary>

`test/test_unary_ufuncs.py`</summary>

<p>

```python
24:    floating_types_and, all_types_and_complex_and, floating_and_complex_types_and, get_all_dtypes, get_all_math_dtypes,
25:    get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
517:    dtypes(*(get_all_int_dtypes() + [torch.bool] +
518:              get_all_fp_dtypes(include_bfloat16=False)))
596:    dtypes(*get_all_fp_dtypes(include_half=True, include_bfloat16=False))
611:        invalid_input_dtypes = get_all_int_dtypes() + \
612:            get_all_complex_dtypes() + \
619:        for dtype in get_all_fp_dtypes(include_half=True, include_bfloat16=False):
1048:    dtypes(*get_all_math_dtypes('cpu'))
1182:    dtypesIfCUDA(*get_all_fp_dtypes())
1190:    dtypesIfCUDA(*get_all_fp_dtypes())
1205:    dtypesIfCUDA(*get_all_fp_dtypes())
1215:    dtypesIfCUDA(*get_all_fp_dtypes())
1307:    dtypes(*(get_all_dtypes(include_bool=False)))
1349:    dtypes(*(get_all_fp_dtypes(include_half=False) +
1350:              get_all_complex_dtypes()))
1351:    dtypesIfCUDA(*(get_all_fp_dtypes(include_half=True) +
1352:                    get_all_complex_dtypes()))
```

</p>
</details>

<details>
<summary>

`test/test_view_ops.py`</summary>

<p>

```python
19:    get_all_dtypes, get_all_int_dtypes, get_all_fp_dtypes, get_all_complex_dtypes
124:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
131:    dtypes(*get_all_dtypes(include_bfloat16=False))
213:            for view_dtype in [*get_all_fp_dtypes(), *get_all_complex_dtypes()]:
220:    dtypes(*get_all_dtypes())
224:        for view_dtype in get_all_dtypes():
305:    dtypes(*get_all_complex_dtypes(include_complex32=True))
343:    dtypes(*get_all_dtypes())
354:    dtypes(*get_all_dtypes())
364:    dtypes(*get_all_dtypes())
374:    dtypes(*get_all_dtypes())
384:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes()))
395:    dtypes(*get_all_complex_dtypes())
426:    dtypes(*get_all_complex_dtypes())
451:    dtypes(*product(get_all_complex_dtypes(), get_all_dtypes()))
1263:    dtypes(*(torch.testing.get_all_dtypes()))
1279:    dtypes(*(torch.testing.get_all_dtypes()))
1405:    dtypes(*(get_all_int_dtypes() + get_all_fp_dtypes(include_bfloat16=False) +
1406:              get_all_complex_dtypes()))
1471:    dtypes(*get_all_dtypes(include_bfloat16=False))
1574:    dtypes(*get_all_dtypes())
1601:    dtypes(*get_all_dtypes(include_bfloat16=False))
1632:    dtypes(*get_all_dtypes(include_bfloat16=False))
1711:        for dt in get_all_dtypes():
1717:        for dt in get_all_dtypes():
1724:        for dt in get_all_dtypes():
```

</p>
</details>

I'm looking forward to your viewpoints. Thanks :)

cc: mruberry kshitij12345 anjali411

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71561

Reviewed By: samdow

Differential Revision: D34856571

Pulled By: mruberry

fbshipit-source-id: 0dca038bcad5cf69906245c496d2e61ac3876335
(cherry picked from commit b058f67b4313143efa714ab105f36e74083131b9)
2022-03-15 20:31:41 +00:00
Duncan Hill
0988dc481a [Codemod][Codemod deprecated unittest asserts] fbcode//caffe2/test (#71708)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71708

In Python 3.2, a number of asserts were deprecated.

In Python 3.11, these asserts are deleted completely. The files in this change still use the deprecated asserts.

Switch over to the supported syntax for 3.2 onwards.

Test Plan: Tested on the internal test suite runner.

Reviewed By: ajtulloch

Differential Revision: D33503694

fbshipit-source-id: a150f296033260acf8365d77b837ce0679f57361
(cherry picked from commit abf60ed97409265222915d8265aaabedd625fd93)
2022-03-15 19:28:52 +00:00
Joel Benjamin Schlosser
30653d164d Fix serialization and deepcopying for wrapper subclasses
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73078
2022-02-24 18:21:25 +00:00
Kurt Mohler
8e7fe87630 Rename Typed/UntypedStorage to _Typed/_UntypedStorage (#72540)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72540

Reviewed By: jbschlosser

Differential Revision: D34216823

Pulled By: bdhirsh

fbshipit-source-id: 1bc9930ab582771ebf02308e035576cd1a0dbe47
(cherry picked from commit 329238f612)
2022-02-15 23:53:01 +00:00
Christian Puhrsch
4a7e07e53e Fix torch.save and detach for CSR Tensor (#71963)
Summary:
Currently saving a CSR Tensor simply fails. This also addresses the segfault encountered in https://github.com/pytorch/pytorch/issues/71652.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71963

Reviewed By: jbschlosser

Differential Revision: D33895938

Pulled By: cpuhrsch

fbshipit-source-id: a333505d3a216705147c2aaaaeb2a0fd0c2a5e43
(cherry picked from commit a88265921c)
2022-02-02 23:59:24 +00:00
Kurt Mohler
b69155f754 Avoid dtype mismatch error in torch.save if storages are unallocated (#68787)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/58970

cc mruberry

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68787

Reviewed By: mruberry

Differential Revision: D32617425

Pulled By: anjali411

fbshipit-source-id: fe7f2374e4ef4428346a0a202cae8e0d382e03ab
2021-11-24 09:51:29 -08:00
Kurt Mohler
bc3d380ed1 Throw error when saving storages that view same data with different type (#66949)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/58970

cc mruberry

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66949

Reviewed By: albanD

Differential Revision: D31926323

Pulled By: anjali411

fbshipit-source-id: f6e7acc0c1968b70a94f9b0b69a32780e8e21a62
2021-11-16 08:44:44 -08:00
Jane Xu
b07371f19c [skip ci] Set test owners for serialization tests (#66862)
Summary:
Action following https://github.com/pytorch/pytorch/issues/66232

cc mruberry

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66862

Reviewed By: saketh-are

Differential Revision: D31828615

Pulled By: janeyx99

fbshipit-source-id: 8d28970eead9d6f26e9ea64b823295d9c9e1469d
2021-10-21 13:22:18 -07:00
Kurt Mohler
5883523c1d Remove dtype from torch.Storage and use only torch.ByteStorage (#62030)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62030

Remove dtype tracking from Python Storage interface, remove all the different `<type>Storage` classes except for `ByteStorage`, and update serialization accordingly, while maintaining as much FC/BC as possible

Fixes https://github.com/pytorch/pytorch/issues/47442

* **THE SERIALIZATION FORMAT IS FULLY FC/BC.** We worked very hard to make sure this is the case. We will probably want to break FC at some point to make the serialization structure of tensors make more sense, but not today.
* There is now only a single torch.ByteStorage class. Methods like `Tensor.set_` no longer check that the dtype of storage is appropriate.
* As we no longer know what dtype of a storage is, we've **removed** the size method from Storage, replacing it with nbytes. This is to help catch otherwise silent errors where you confuse number of elements with number of bytes.
* `Storage._new_shared` takes a `nbytes` kwarg and will reject previous positional only calls.  `Storage._new_with_file` and `_set_from_file` require explicit element size arguments.
* It's no longer possible to convert storages to different types using the float/double/etc methods. Instead, do the conversion using a tensor.
* It's no longer possible to allocate a typed storage directly using FloatStorage/DoubleStorage/etc constructors. Instead, construct a tensor and extract its storage. The classes still exist but they are used purely for unpickling.
* The preexisting serialization format stores dtype with storage, and in fact this dtype is used to determine the dtype of the tensor overall.
 To accommodate this case, we introduce a new TypedStorage concept that exists only during unpickling time which is used to temporarily store the dtype so we can construct a tensor. **If you overrode the handling of pickling/unpickling, you MUST add handling for TypedStorage** or your serialization code will degrade to standard file-based serialization.

Original pull request: https://github.com/pytorch/pytorch/pull/59671

Reviewed By: soulitzer, ngimel

Differential Revision: D29466819

Pulled By: ezyang

fbshipit-source-id: 4a14e5d3c2b08e06e558683d97f7378a3180b00e
2021-10-05 13:50:34 -07:00
Alban Desmaison
7c62b6e973 add deepcopy support to subclasses (#65584)
Summary:
Happy to get any feedback on how to make this code cleaner!

This:
- Fix Tensor attribute deepcopy BC-breaking?
- Add a test for Tensor attribute deepcopy
- Fix subclass deepcopy
- Moves the subclass serialization tests into their own class not to interfere with other serialization test logic
- Add a test for subclass deepcopy

cc ezyang gchanan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65584

Reviewed By: gchanan

Differential Revision: D31206590

Pulled By: albanD

fbshipit-source-id: 74a8f0767f4933b9c941fbea880a8fd1b893ea2f
2021-09-27 14:36:22 -07:00
Shen Li
1022443168 Revert D30279364: [codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: revert-hammer

Differential Revision:
D30279364 (b004307252)

Original commit changeset: c1ed77dfe43a

fbshipit-source-id: eab50857675c51e0088391af06ec0ecb14e2347e
2021-08-12 11:45:01 -07:00
Zsolt Dollenstein
b004307252 [codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle

Reviewed By: zertosh

Differential Revision: D30279364

fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
2021-08-12 10:58:35 -07:00
Alban Desmaison
e6a227465b Add serialization support for slots and subclass getstate/setstate (#62745)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62745

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D30113112

Pulled By: albanD

fbshipit-source-id: 6c562d0c060fb0280e5e3d432bb42fb833e6d500
2021-08-05 06:49:44 -07:00
Edward Yang
cf1f59452b Hacky support for meta tensor serialization. (#62192)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62192

This support is hacky because it doesn't preserve meta tensor storage
sharing (e.g., if you serialize a model with shared storage, e.g., a
tensor and a view on a tensor, when I deserialize the viewing
relationship will be broken and these are just different tensors.) The
hack is also durable, in the sense that we will be on the hook for
supporting `_rebuild_meta_tensor_no_storage` in perpetuity in the
future, even if we change our mind about the serialization format.

This unblocks an FB production use case. I didn't add C++ support to minimize
blast area of this patch.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D29910535

Pulled By: ezyang

fbshipit-source-id: d98dcdd0108dfc3ae730a071d3c583b6d0281d21
2021-07-26 14:33:45 -07:00