Commit Graph

1481 Commits

Author SHA1 Message Date
Mikayla Gawarecki
a2d4fea872 [easy] Move state_dict hooks tests to test_module_hooks and decorate tests that call load_state_dict with swap (#126906)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126906
Approved by: https://github.com/albanD
2024-06-10 21:50:17 +00:00
laithsakka
3a2d0755a4 enable test_ParameterList with dynamo if nn module inlining enabled only (#128308)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128308
Approved by: https://github.com/anijain2305
2024-06-10 21:25:40 +00:00
Mikayla Gawarecki
65aa16f968 Revert "Default XLA to use swap_tensors path in nn.Module._apply (#126814)" (#128170)
https://github.com/pytorch/pytorch/issues/128165 :(

This reverts commit a7b1dd82ff.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128170
Approved by: https://github.com/drisspg, https://github.com/albanD
2024-06-07 01:44:14 +00:00
Mikayla Gawarecki
a7b1dd82ff Default XLA to use swap_tensors path in nn.Module._apply (#126814)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126814
Approved by: https://github.com/JackCaoG, https://github.com/albanD
ghstack dependencies: #127313
2024-06-04 21:40:49 +00:00
PyTorch MergeBot
17dea09b15 Revert "Default XLA to use swap_tensors path in nn.Module._apply (#126814)"
This reverts commit bfdec93395.

Reverted https://github.com/pytorch/pytorch/pull/126814 on behalf of https://github.com/izaitsevfb due to suspicious build instructions count regression, see [D58015016](https://www.internalfb.com/diff/D58015016) ([comment](https://github.com/pytorch/pytorch/pull/126814#issuecomment-2143545818))
2024-06-01 18:46:16 +00:00
Mikayla Gawarecki
bfdec93395 Default XLA to use swap_tensors path in nn.Module._apply (#126814)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126814
Approved by: https://github.com/JackCaoG, https://github.com/albanD
ghstack dependencies: #127313
2024-05-30 18:28:13 +00:00
Mikayla Gawarecki
cd06ae0cb8 Relax use_count constraints for swap_tensors when AccumulateGrad holds a reference (#127313)
### Before this PR:
`torch.utils.swap_tensors(a, b)` required the `use_count` of `a` and `b` to be 1

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here would fail due to the reference held by AccumulateGrad node, which is not cleaned up after backward
# torch.utils.swap_tensors(a, b)
del out
# Calling swap_tensors here would pass
torch.utils.swap_tensors(a, b)
```
### After this PR:
`torch.utils.swap_tensors(a, b)` requires the `use_count` of `a` and `b` to be 1 or 2 IF the second reference is held by `AccumulateGrad`

A pre-hook will be registered on the `AccumulateGrad` node so that it will fail if it is called (i.e. if user attempts to backward through the graph).

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here is ok
torch.utils.swap_tensors(a, b)
# If we ever backward to the AccumulateGrad node it will error that it was poisoned by swap_tensors
```

### Application to `nn.Module`

This issue is especially pertinent in context of `nn.Module` where parameters will have `AccumulateGrad` nodes initialized after forward. Specifically, this is intended to address https://github.com/pytorch/pytorch/pull/126814#issuecomment-2127777866. Previously, this would fail at the `m.cpu()` but we want users to be able to do something like the following, and instead raise an error if the user ever attempts to backward through the poisoned `AccumulateGrad` node

```python
import torch
import torch.nn as nn
m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)
out.sum().backward()
m.cpu()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127313
Approved by: https://github.com/soulitzer
2024-05-30 07:06:55 +00:00
JackCaoG
38a33c3202 don't call .item in onehot for XLA (#127335)
We found that `nn.function.one_hot` will cause a graph break due to the item call in the native implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127335
Approved by: https://github.com/ezyang
2024-05-29 20:37:26 +00:00
dan_the_3rd
c133665d4a [CUDA] Parallelize upsampling OPS across the batch/channel dimension. (#127082)
This can make this operation 200x+ faster on modern GPUs for small grid sizes, as otherwise this kernel is scheduled with a single block (!)

Tested on A100 with:
```
python test/test_nn.py TestNNDeviceTypeCUDA
```

**Benchmarks FW**
Ran on A100 / bf16
## Forward pass benchmarks

| batch size | input size | output size | before runtime (mem bandwidth) | after runtime (mem bandwidth) | speedup |
|------------|------------|-------------|------------------|-----------------|---------|
| 768 | 16x16 | 6x6 | 5855us (0.07 GB/s) | 38us (10 GB/s) | 154x |
| 768 | 16x16 | 7x7 | 5214us (0.08 GB/s) | 37us (11 GB/s) | 138x |
| 768 | 16x16 | 14x14 | 2314us (0.27 GB/s) | 36us (17 GB/s) | 63x |
| 768 | 16x16 | 16x16 | 1232us (0.59 GB/s) | 33us (21 GB/s) | 36x |
| 768 | 32x32 | 6x6 | 19442us (0.07 GB/s) | 98us (15 GB/s) | 197x |
| 768 | 32x32 | 7x7 | 16918us (0.09 GB/s) | 89us (17 GB/s) | 188x |
| 768 | 32x32 | 14x14 | 6023us (0.28 GB/s) | 69us (25 GB/s) | 86x |
| 768 | 32x32 | 16x16 | 3455us (0.52 GB/s) | 55us (32 GB/s) | 62x |
| 768 | 48x48 | 6x6 | 38597us (0.08 GB/s) | 179us (18 GB/s) | 214x |
| 768 | 48x48 | 7x7 | 34700us (0.09 GB/s) | 163us (20 GB/s) | 211x |
| 768 | 48x48 | 14x14 | 10647us (0.33 GB/s) | 112us (31 GB/s) | 94x |
| 768 | 48x48 | 16x16 | 7388us (0.49 GB/s) | 100us (36 GB/s) | 73x |
| 768 | 64x64 | 6x6 | 76288us (0.07 GB/s) | 310us (19 GB/s) | 246x |
| 768 | 64x64 | 7x7 | 54981us (0.1 GB/s) | 257us (23 GB/s) | 213x |
| 768 | 64x64 | 14x14 | 16565us (0.37 GB/s) | 169us (36 GB/s) | 97x |
| 768 | 64x64 | 16x16 | 12037us (0.51 GB/s) | 141us (43 GB/s) | 84x |
| 1024 | 16x16 | 6x6 | 8123us (0.06 GB/s) | 44us (12 GB/s) | 183x |
| 1024 | 16x16 | 7x7 | 7017us (0.08 GB/s) | 45us (12 GB/s) | 155x |
| 1024 | 16x16 | 14x14 | 3150us (0.27 GB/s) | 45us (18 GB/s) | 69x |
| 1024 | 16x16 | 16x16 | 1695us (0.57 GB/s) | 41us (23 GB/s) | 40x |
| 1024 | 32x32 | 6x6 | 25918us (0.07 GB/s) | 120us (16 GB/s) | 214x |
| 1024 | 32x32 | 7x7 | 22622us (0.09 GB/s) | 108us (18 GB/s) | 208x |
| 1024 | 32x32 | 14x14 | 8245us (0.28 GB/s) | 87us (26 GB/s) | 94x |
| 1024 | 32x32 | 16x16 | 4599us (0.53 GB/s) | 68us (35 GB/s) | 67x |
| 1024 | 48x48 | 6x6 | 51486us (0.08 GB/s) | 219us (20 GB/s) | 234x |
| 1024 | 48x48 | 7x7 | 46501us (0.09 GB/s) | 202us (22 GB/s) | 229x |
| 1024 | 48x48 | 14x14 | 14280us (0.33 GB/s) | 145us (32 GB/s) | 98x |
| 1024 | 48x48 | 16x16 | 9877us (0.49 GB/s) | 125us (39 GB/s) | 79x |
| 1024 | 64x64 | 6x6 | 101731us (0.07 GB/s) | 378us (20 GB/s) | 268x |
| 1024 | 64x64 | 7x7 | 73465us (0.1 GB/s) | 320us (24 GB/s) | 229x |
| 1024 | 64x64 | 14x14 | 22109us (0.37 GB/s) | 218us (37 GB/s) | 101x |
| 1024 | 64x64 | 16x16 | 16081us (0.51 GB/s) | 178us (46 GB/s) | 90x |
| 1536 | 16x16 | 6x6 | 12546us (0.06 GB/s) | 61us (13 GB/s) | 205x |
| 1536 | 16x16 | 7x7 | 11064us (0.07 GB/s) | 63us (13 GB/s) | 175x |
| 1536 | 16x16 | 14x14 | 4839us (0.26 GB/s) | 62us (20 GB/s) | 77x |
| 1536 | 16x16 | 16x16 | 2630us (0.55 GB/s) | 59us (24 GB/s) | 44x |
| 1536 | 32x32 | 6x6 | 38898us (0.07 GB/s) | 170us (17 GB/s) | 227x |
| 1536 | 32x32 | 7x7 | 34079us (0.09 GB/s) | 155us (19 GB/s) | 219x |
| 1536 | 32x32 | 14x14 | 12632us (0.27 GB/s) | 124us (28 GB/s) | 101x |
| 1536 | 32x32 | 16x16 | 6900us (0.53 GB/s) | 98us (37 GB/s) | 70x |
| 1536 | 48x48 | 6x6 | 77272us (0.08 GB/s) | 316us (21 GB/s) | 243x |
| 1536 | 48x48 | 7x7 | 70153us (0.09 GB/s) | 291us (23 GB/s) | 240x |
| 1536 | 48x48 | 14x14 | 21500us (0.33 GB/s) | 208us (34 GB/s) | 103x |
| 1536 | 48x48 | 16x16 | 14851us (0.49 GB/s) | 181us (40 GB/s) | 81x |
| 1536 | 64x64 | 6x6 | 152669us (0.07 GB/s) | 548us (21 GB/s) | 278x |
| 1536 | 64x64 | 7x7 | 110348us (0.1 GB/s) | 466us (25 GB/s) | 236x |
| 1536 | 64x64 | 14x14 | 33350us (0.36 GB/s) | 316us (38 GB/s) | 105x |
| 1536 | 64x64 | 16x16 | 24173us (0.51 GB/s) | 263us (47 GB/s) | 91x |
| 4096 | 16x16 | 6x6 | 34638us (0.06 GB/s) | 138us (16 GB/s) | 249x |
| 4096 | 16x16 | 7x7 | 31590us (0.07 GB/s) | 144us (16 GB/s) | 218x |
| 4096 | 16x16 | 14x14 | 13203us (0.26 GB/s) | 149us (23 GB/s) | 88x |
| 4096 | 16x16 | 16x16 | 7328us (0.53 GB/s) | 143us (27 GB/s) | 51x |
| 4096 | 32x32 | 6x6 | 103802us (0.07 GB/s) | 405us (19 GB/s) | 256x |
| 4096 | 32x32 | 7x7 | 91354us (0.08 GB/s) | 372us (22 GB/s) | 245x |
| 4096 | 32x32 | 14x14 | 34501us (0.26 GB/s) | 312us (29 GB/s) | 110x |
| 4096 | 32x32 | 16x16 | 18465us (0.52 GB/s) | 247us (39 GB/s) | 74x |
## Backward pass benchmarks

| batch size | input size | output size | before runtime (mem bandwidth) | after runtime (mem bandwidth) | speedup |
|------------|------------|-------------|------------------|-----------------|---------|
| 768 | 16x16 | 6x6 | 78656us (0.0 GB/s) | 323us (1 GB/s) | 243x |
| 768 | 16x16 | 7x7 | 67167us (0.0 GB/s) | 292us (1 GB/s) | 230x |
| 768 | 16x16 | 14x14 | 27478us (0.02 GB/s) | 229us (2 GB/s) | 119x |
| 768 | 16x16 | 16x16 | 131us (5.59 GB/s) | 56us (13 GB/s) | 2x |
| 768 | 32x32 | 6x6 | 271752us (0.0 GB/s) | 888us (1 GB/s) | 305x |
| 768 | 32x32 | 7x7 | 224110us (0.0 GB/s) | 813us (1 GB/s) | 275x |
| 768 | 32x32 | 14x14 | 85365us (0.02 GB/s) | 450us (3 GB/s) | 189x |
| 768 | 32x32 | 16x16 | 67700us (0.02 GB/s) | 360us (5 GB/s) | 187x |
| 768 | 48x48 | 6x6 | 593709us (0.0 GB/s) | 1988us (1 GB/s) | 298x |
| 768 | 48x48 | 7x7 | 485566us (0.0 GB/s) | 1694us (1 GB/s) | 286x |
| 768 | 48x48 | 14x14 | 164059us (0.02 GB/s) | 897us (3 GB/s) | 182x |
| 768 | 48x48 | 16x16 | 134317us (0.02 GB/s) | 674us (5 GB/s) | 199x |
| 768 | 64x64 | 6x6 | 1026651us (0.0 GB/s) | 3360us (1 GB/s) | 305x |
| 768 | 64x64 | 7x7 | 770901us (0.0 GB/s) | 2584us (2 GB/s) | 298x |
| 768 | 64x64 | 14x14 | 277850us (0.02 GB/s) | 1556us (3 GB/s) | 178x |
| 768 | 64x64 | 16x16 | 236245us (0.02 GB/s) | 1144us (5 GB/s) | 206x |
| 1024 | 16x16 | 6x6 | 106638us (0.0 GB/s) | 341us (1 GB/s) | 312x |
| 1024 | 16x16 | 7x7 | 90886us (0.0 GB/s) | 314us (1 GB/s) | 288x |
| 1024 | 16x16 | 14x14 | 36572us (0.02 GB/s) | 292us (2 GB/s) | 124x |
| 1024 | 16x16 | 16x16 | 171us (5.69 GB/s) | 56us (17 GB/s) | 3x |
| 1024 | 32x32 | 6x6 | 356900us (0.0 GB/s) | 936us (2 GB/s) | 380x |
| 1024 | 32x32 | 7x7 | 299139us (0.0 GB/s) | 870us (2 GB/s) | 343x |
| 1024 | 32x32 | 14x14 | 113205us (0.02 GB/s) | 576us (4 GB/s) | 196x |
| 1024 | 32x32 | 16x16 | 90886us (0.02 GB/s) | 458us (5 GB/s) | 198x |
| 1024 | 48x48 | 6x6 | 786896us (0.0 GB/s) | 2127us (2 GB/s) | 369x |
| 1024 | 48x48 | 7x7 | 640515us (0.0 GB/s) | 1837us (2 GB/s) | 348x |
| 1024 | 48x48 | 14x14 | 218720us (0.02 GB/s) | 1152us (4 GB/s) | 189x |
| 1024 | 48x48 | 16x16 | 178827us (0.02 GB/s) | 863us (5 GB/s) | 207x |
| 1024 | 64x64 | 6x6 | 1379991us (0.0 GB/s) | 3589us (2 GB/s) | 384x |
| 1024 | 64x64 | 7x7 | 1047466us (0.0 GB/s) | 2774us (2 GB/s) | 377x |
| 1024 | 64x64 | 14x14 | 370139us (0.02 GB/s) | 1999us (4 GB/s) | 185x |
| 1024 | 64x64 | 16x16 | 316501us (0.02 GB/s) | 1470us (5 GB/s) | 215x |
| 1536 | 16x16 | 6x6 | 159057us (0.0 GB/s) | 477us (1 GB/s) | 332x |
| 1536 | 16x16 | 7x7 | 135578us (0.0 GB/s) | 441us (1 GB/s) | 306x |
| 1536 | 16x16 | 14x14 | 53002us (0.02 GB/s) | 400us (3 GB/s) | 132x |
| 1536 | 16x16 | 16x16 | 252us (5.79 GB/s) | 55us (26 GB/s) | 4x |
| 1536 | 32x32 | 6x6 | 545653us (0.0 GB/s) | 1323us (2 GB/s) | 412x |
| 1536 | 32x32 | 7x7 | 447491us (0.0 GB/s) | 1248us (2 GB/s) | 358x |
| 1536 | 32x32 | 14x14 | 173491us (0.02 GB/s) | 787us (4 GB/s) | 220x |
| 1536 | 32x32 | 16x16 | 136395us (0.02 GB/s) | 633us (5 GB/s) | 215x |
| 1536 | 48x48 | 6x6 | 1198639us (0.0 GB/s) | 3057us (2 GB/s) | 392x |
| 1536 | 48x48 | 7x7 | 985549us (0.0 GB/s) | 2645us (2 GB/s) | 372x |
| 1536 | 48x48 | 14x14 | 331419us (0.02 GB/s) | 1581us (4 GB/s) | 209x |
| 1536 | 48x48 | 16x16 | 270972us (0.02 GB/s) | 1186us (6 GB/s) | 228x |
| 1536 | 64x64 | 6x6 | 2094282us (0.0 GB/s) | 5214us (2 GB/s) | 401x |
| 1536 | 64x64 | 7x7 | 1593449us (0.0 GB/s) | 4086us (2 GB/s) | 389x |
| 1536 | 64x64 | 14x14 | 559244us (0.02 GB/s) | 2828us (4 GB/s) | 197x |
| 1536 | 64x64 | 16x16 | 469471us (0.02 GB/s) | 2057us (6 GB/s) | 228x |
| 4096 | 16x16 | 6x6 | 430494us (0.0 GB/s) | 1008us (2 GB/s) | 427x |
| 4096 | 16x16 | 7x7 | 360346us (0.0 GB/s) | 1015us (2 GB/s) | 354x |
| 4096 | 16x16 | 14x14 | 142868us (0.02 GB/s) | 988us (3 GB/s) | 144x |
| 4096 | 16x16 | 16x16 | 658us (5.93 GB/s) | 56us (69 GB/s) | 11x |
| 4096 | 32x32 | 6x6 | 1425928us (0.0 GB/s) | 2796us (2 GB/s) | 509x |
| 4096 | 32x32 | 7x7 | 1188862us (0.0 GB/s) | 2906us (2 GB/s) | 409x |
| 4096 | 32x32 | 14x14 | 464286us (0.02 GB/s) | 1965us (4 GB/s) | 236x |
| 4096 | 32x32 | 16x16 | 363903us (0.02 GB/s) | 1588us (6 GB/s) | 229x |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127082
Approved by: https://github.com/fmassa
2024-05-24 21:17:12 +00:00
PyTorch MergeBot
b36e390b6c Revert "Default XLA to use swap_tensors path in nn.Module._apply (#126814)"
This reverts commit eb41ed5d90.

Reverted https://github.com/pytorch/pytorch/pull/126814 on behalf of https://github.com/mikaylagawarecki due to broke xla ci ([comment](https://github.com/pytorch/pytorch/pull/126814#issuecomment-2127719337))
2024-05-23 17:43:06 +00:00
Mikayla Gawarecki
eb41ed5d90 Default XLA to use swap_tensors path in nn.Module._apply (#126814)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126814
Approved by: https://github.com/JackCaoG, https://github.com/albanD
2024-05-23 15:43:32 +00:00
Simon Fan
7530cfe7e4 [dynamo][flaky tests] test_conv_empty_input_* (#126790)
Run CI, maybe fixes https://github.com/pytorch/pytorch/issues/126178

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126790
Approved by: https://github.com/mikaylagawarecki
2024-05-22 03:14:21 +00:00
Martim Mendes
d54c28e7fc Added error checks for invalid inputs on thnn_conv2d (#121906)
Fixes #121188
Prevent Segmentation Fault in 'torch._C._nn.thnn_conv2d'

Previously, calling 'torch._C._nn.thnn_conv2d' with invalid arguments for padding, stride, and kernel_size would result in a segmentation fault. This issue has been resolved by implementing argument validation (using Torch Check). Now, when invalid arguments are detected, a runtime error is raised with a debug message detailing the correct format.

Additionally, this commit includes tests to cover the three referenced cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121906
Approved by: https://github.com/janeyx99
2024-05-17 23:41:48 +00:00
vfdev-5
415a8f6398 Fixed issue in affine_grid_backward when grad_grid is non-contiguous (#124370)
Description:
- replaced .view with .reshape to fix the problem when grad_grid is channels last 2d/3d
- added a consistency test

Fixes #124154

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124370
Approved by: https://github.com/lezcano
2024-04-18 16:30:10 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Aaron Gokaslan
1d6c5972c1 [BE]: Optimize min/max/sum comprehensions C419 (#123960)
Automatic fixes that replaces certain list comprehensions with generator ones where appropriate so that they are immediately consumed. This is preview functionality in ruff for rule C419 and it was automatically applied.

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
2024-04-12 23:54:15 +00:00
eqy
2f0fc04fa3 [CUDA][64-bit indexing] Bump large tensor threshold of test_cross_entropy_large_tensor to 70GiB (#123772)
`torch.cuda.max_memory_reserved()` here shows 68729962496 (about 65546 MiB).

CC @malfet @crcrpar

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123772
Approved by: https://github.com/mikaylagawarecki
2024-04-12 19:18:20 +00:00
David Yan
63c24f73ef Upsample2d backwards to int64_t (#123682)
Summary: To unblock training where upsamplenearest2d involves input or output tensors which are larger than 2^31. Comes up frequently in image & video applications.

Test Plan:
```
buck2 test mode/opt //caffe2/test:test_nn_cuda -- test_upsamplingnearest2d_backward_64bit_indexing
```

Benchmarking (N5207417):
```
device_ms, cpu_ms, gb/device_ms*1000
# before changes
118.03993721008301 124.09385920000001 98.72685525972494
# after changes
118.05780944824218 124.10893509999994 98.71190944734577
```

Differential Revision: D55625666

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123682
Approved by: https://github.com/ezyang
2024-04-10 20:26:08 +00:00
Tailing Yuan
041be901b3 fix ctc_loss zero-length/neg-length corner cases (#123193)
Fixes #84827, fixes #86596, fixes #88047, fixes #89208.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123193
Approved by: https://github.com/mikaylagawarecki
2024-04-09 20:39:39 +00:00
eqy
d5b5012dc4 [CUDA] Raise softmax_forward_64bit_indexing GPU memory requirement (#116075)
printing `torch.cuda.memory_summary()` shows ~41GiB reserved at the end of this test, not sure how it was passing previously on CUDA.

CC @ptrblck @malfet

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116075
Approved by: https://github.com/ptrblck, https://github.com/malfet
2024-03-21 00:03:17 +00:00
David Yan
6915a5be70 Increase numel limit to 2^63 for replicatepad1d (#122199)
Summary: As title

Test Plan:
```
CUDA_VISIBLE_DEVICES=5 buck2 test mode/opt //caffe2/test:test_nn_cuda -- test_replicatepad_64bit_indexing
```

Also benchmarked in N5106027
```
device_ms, cpu_ms, gb/device_ms*1000
# before changes
11.058772478103638 18.912256770000006 735.4118906278957
# after changes
10.621162576675415 18.58972748 765.7121070725207
```

Differential Revision: D55030372

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122199
Approved by: https://github.com/ezyang
2024-03-19 21:55:34 +00:00
jmarin
a2854ae904 Bugfix consume_prefix_in_state_dict_if_present function to keep the order of the state_dict (#117464)
This PR proposes to keep the original order as the original state_dict, as the issue creator proposed. It also removes a bug concerning how ``_metadata`` is handled (see below), as well as other small changes to properly remove the prefix when is present.

In the original code, ``_metadata`` was handled as a ``key``.

```
    # also strip the prefix in metadata if any.
    if "_metadata" in state_dict:
```

This is not the case, ``_metadata`` is actually an ``attribute``. Hence, the previous condition is changed to:

```
    # also strip the prefix in metadata if any.
    if hasattr(state_dict, "_metadata"):
```

This PR also includes the necessary test.

Fixes #106942

Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117464
Approved by: https://github.com/mikaylagawarecki
2024-03-07 04:00:49 +00:00
Eddie Yan
967dd31621 [cuDNN] Cleanup cuDNN < 8.1 ifdefs (#120862)
Follow-up of #95722

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120862
Approved by: https://github.com/Skylion007
2024-03-07 01:46:25 +00:00
Lourencom
69cedc16c5 Add padding dimension checks and tests (#121298)
Fixes #121093

Previously, calling the following functions with invalid padding dimensions would cause a segmentation fault:
```
torch._C._nn.replication_pad1d, torch._C._nn.replication_pad3d, torch._C._nn.replication_pad3d
```

To fix, added condition checking to raise a runtime error with a debug message instead, specifying the correct dimensions necessary.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121298
Approved by: https://github.com/mikaylagawarecki
2024-03-06 21:55:34 +00:00
lancerts
099ff51d45 torch check the division by zero in batch_norm_update_stats (#120882)
Fixes #120803

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120882
Approved by: https://github.com/CaoE, https://github.com/malfet
2024-03-06 05:40:21 +00:00
mingfeima
34e3f6f3c9 fix segfault in torch.native_channel_shuffle when input is empty (#121199)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

fix https://github.com/pytorch/pytorch/issues/121092

`torch.channel_shuffle` could handle empty inputs correctly. `torch.native_channel_shuffle` bypassed the `numel == 0` check, this causes divided by zero in underlying kernel.

* __->__ #121199

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121199
Approved by: https://github.com/malfet
2024-03-06 00:46:36 +00:00
yuanx749
e317e39a02 Fix nonlinearity arg issue in RNN (#120234)
Fixes #114617

This PR fix the the issue with `nonlinearity`, so that it can be passed as arg or kwarg.

Alternatively, if making `nonlinearity` kwarg-only is preferred, I can revert to another commit. cc @mikaylagawarecki
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120234
Approved by: https://github.com/mikaylagawarecki
2024-02-28 20:53:18 +00:00
Tobias Ringwald
d9a1b25807 Fixed an issue where nn.Linear would cause an internal int underflow … (#119221)
…when trying to reshape a scalar input.

Fixes #119161

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119221
Approved by: https://github.com/albanD
2024-02-08 21:06:34 +00:00
lancerts
b51b27922b Add to_empty() suggestion in the error message (#119353)
Fixes #119293, the comprehensive documentation is [here](0f478d9d61/docs/source/meta.rst (id11)).
Just added the suggestion into the error message so it is more informative to user.

@albanD

Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119353
Approved by: https://github.com/mikaylagawarecki
2024-02-08 18:30:02 +00:00
Mikayla Gawarecki
d5a718d27b Add swap_tensors path to nn.Module._apply (#117167)
Added `torch.__future__.{get/set}_swap_module_params_on_conversion` that defaults to `False` for now, but we probably want to modify  to override this and default to `True` in `nn.Module._apply` if input is a tensor subclass.

From offline discussion, for now we are **not** allowing `swap_tensor` after the first module forward has been run*** if the autograd graph is still alive. The reason being that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1.  The first forward pass will install `AccumulateGrad` nodes on each param, which [bump the refcount of the associated TensorImpl](6cf1fc66e3/torch/csrc/autograd/variable.cpp (L307)). **Future work might be to swap the refs that the `AccumulateGrad` nodes hold if it is necessary.**

***From this, it might seem like we don't need to handle gradients. However, I still handle the grads for the edge case that the grads are set via `p.grad = grad` OR the autograd graph is no longer alive because the output has been garbage collected.

If any `swap_tensors` fails on any of the parameters in the `nn.Module` we raise an error.

**`RNNBase` overrides `nn.Module._apply()` and installs weakrefs on some parameters. As a result, all modules that inherit from `RNNBase` (`RNN`, `GRU` and `LSTM`) cannot use the`swap_tensors` path as of now**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117167
Approved by: https://github.com/albanD
ghstack dependencies: #118028
2024-02-07 18:55:44 +00:00
Mikayla Gawarecki
b92819a039 Move nn.Module.load_state_dict tests from test_nn.py to separate file (#118028)
Move these tests out so in https://github.com/pytorch/pytorch/pull/117913 where we can to run these tests with both `torch.nn.utils.set_swap_module_params_on_conversion({True/False})`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118028
Approved by: https://github.com/albanD
2024-02-05 20:17:28 +00:00
Tobias Ringwald
2de327cedc Fixed an illegal memory access in cross entropy loss when using an index that is not a valid class (#117561)
…dex that is not a valid class.

Fixes #117532.

Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117561
Approved by: https://github.com/mikaylagawarecki
2024-02-02 11:03:16 +00:00
PyTorch MergeBot
df048f4da4 Revert "[RELAND] Remove deprecated fbgemm operators (#112153)"
This reverts commit 19e8ba95e5.

Reverted https://github.com/pytorch/pytorch/pull/112153 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/112153#issuecomment-1921965780))
2024-02-01 18:35:19 +00:00
vfdev-5
a1dd367716 Fixed error in bicubic upsampling aa=false for uint8 input (#118389)
Description:
- Fixed error in bicubic upsampling aa=false for uint8 input. This is seen in the test suite:
```diff
- self.assertLess(diff.max(), 15)
+ self.assertLess(diff.max(), 5)
```
While reducing the input range we do not fully remove the clipping effect that's why the threshold is 5 and not around 1.

- Renamed methods
- The error is mostly visible for upsampling (smaller -> larger) mode on the boundary values

More details on the bug:
For uint8 input and antialising=False we are using separable algorithm (using temp buffers and interpolating dimensions one by one) where interpolation weights and input indices are computed and stored using index ranges: `index_min` and `index_size`; weights outside of the `index_size` are zeros. For example, for an output point we can have index_min=10 and index_size=4 and 4 non-zero weights: so the output value is computed as
```
out_value = sum([src[i + index_min] * w for i, w in zip(range(4), weights) ])
```
When computing index ranges and weights for output points near the boundaries we should clamp `index_min` between 0 and input_size and `index_size` becomes smaller than 4. This approach is OK for antialiasing=True but is not correct for antialiasing=False where weights are computed incorrectly:
```
-- output index i= 0
regular float32 approach:
source indices: [-2, -1, 0, 1] -> outbounded values are clamped to boundaries -> [0, 0, 0, 1]
interp weights: [-0.07200000000000006, 0.4600000000000001, 0.72, -0.1080000000000001]

separable uint8 approach:
source indices coming from index ranges (min, size): [0, 1]
incorrect interp weights computed with current implementation : [1.1764705882352944, -0.17647058823529432, 0.0, 0.0]
fixed interp weights in the PR: [1.108, -0.1080000000000001, 0.0, 0.0]
Note: weight value corresponding to source index 0 is 1.108 = -0.07200000000000006 + 0.4600000000000001 + 0.72 and weight value corresponding to source index 1 is -0.1080000000000001 is the same as in f32 approach.
```

Quick benchmark to ensure perfs no regression:

```
[------------------------------------------------------------------------------------ Resize ------------------------------------------------------------------------------------]
                                                                               |  torch (2.3.0a0+gitfda85a6) PR  |  torch (2.3.0a0+git0d1e705) Nightly  |  Speed-up: PR vs Nightly
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_first bilinear (400, 400) -> (224, 224) aa=False  |        440.996 (+-2.044)        |          470.824 (+-5.927)           |      1.068 (+-0.000)
      3 torch.uint8 channels_first bicubic (400, 400) -> (224, 224) aa=False   |        463.565 (+-1.519)        |          497.231 (+-10.825)          |      1.073 (+-0.000)
      3 torch.uint8 channels_first bilinear (400, 400) -> (700, 700) aa=False  |       1717.000 (+-28.589)       |         1915.570 (+-43.397)          |      1.116 (+-0.000)
      3 torch.uint8 channels_first bicubic (400, 400) -> (700, 700) aa=False   |       1801.954 (+-22.391)       |         1981.501 (+-37.034)          |      1.100 (+-0.000)
      3 torch.uint8 channels_last bilinear (400, 400) -> (224, 224) aa=False   |        199.599 (+-0.851)        |          196.535 (+-3.788)           |      0.985 (+-0.000)
      3 torch.uint8 channels_last bicubic (400, 400) -> (224, 224) aa=False    |        243.126 (+-0.681)        |          240.695 (+-2.306)           |      0.990 (+-0.000)
      3 torch.uint8 channels_last bilinear (400, 400) -> (700, 700) aa=False   |        686.270 (+-2.870)        |          687.769 (+-17.863)          |      1.002 (+-0.000)
      3 torch.uint8 channels_last bicubic (400, 400) -> (700, 700) aa=False    |        899.509 (+-5.377)        |          899.063 (+-9.001)           |      1.000 (+-0.000)

Times are in microseconds (us).
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118389
Approved by: https://github.com/NicolasHug
ghstack dependencies: #118388
2024-02-01 14:14:32 +00:00
vfdev-5
eba4bd6b86 Updated test_upsamplingBiMode2d_consistency (#118388)
Description:
- Lowered error thresholds and added input range for bicubic to make visible the inconsistency error in the implementation for upsampling (smaller -> larger) bicubic aa=false mode for uint8 input dtype
- Updated out-dated comments
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118388
Approved by: https://github.com/NicolasHug
2024-02-01 09:22:23 +00:00
Peter Bell
19e8ba95e5 [RELAND] Remove deprecated fbgemm operators (#112153)
These operators are not used and have been deprecated since #72690
(Feb 2022).

BC-breaking message:

`TorchScript` models that were exported with the deprecated
`torch.jit.quantized` API will no longer be loadable, as the required
internal operators have been removed.
Please re-export your models using the newer `torch.ao.quantization` API
instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112153
Approved by: https://github.com/jerryzh168
2024-01-30 16:32:37 +00:00
Oguz Ulgen
3b38f7b266 Remove skips for passing tests (#118000)
These tests were already passing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118000
Approved by: https://github.com/yanboliang
2024-01-23 16:11:38 +00:00
PyTorch MergeBot
bb28965924 Revert "Remove skips for passing tests (#118000)"
This reverts commit 3c339b5b21.

Reverted https://github.com/pytorch/pytorch/pull/118000 on behalf of https://github.com/oulgen due to test passing on diff but failing on hud... ([comment](https://github.com/pytorch/pytorch/pull/118000#issuecomment-1905351752))
2024-01-23 06:10:25 +00:00
Oguz Ulgen
3c339b5b21 Remove skips for passing tests (#118000)
These tests were already passing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118000
Approved by: https://github.com/yanboliang
2024-01-23 03:41:23 +00:00
haozhe.zhu@intel.com
0ae952db76 enable mkldnn bf32 matmul (#116015)
### Testing
FP32 matmul vs. mkldnn BF32 matmul  on SPR

single core:

Input | BF32   / ms | FP32  /   ms | Speed up
-- | -- | -- | --
M: 128, N: 128, K: 128, trans_a: False, trans_b: False | 32.842 | 38.279 | 1.165
M: 128, N: 256, K: 128, trans_a: False, trans_b: False | 38.590 | 73.967 | 1.917
M: 8192, N: 768, K: 768, trans_a: False, trans_b: False | 18456.267 | 74588.002 | 4.041

56 cores:
Input | BF32   / ms | FP32 /   ms | Speed up
-- | -- | -- | --
M: 8192, N: 768, K: 768, trans_a: False, trans_b: False | 1199.400 | 1715.548 | 1.430
M: 8192, N: 768, K: 768, trans_a: False, trans_b: True |1129.204 | 1708.912 |  1.513
M: 8192, N: 768, K: 3072, trans_a: False, trans_b: False | 3655.915  | 7992.877 | 2.186
M: 8192, N: 768, K: 3072, trans_a: False, trans_b: True | 3707.993 |  8026.191 | 2.165
Batch: 768, M: 128, N: 64, K: 128  | 1296.419 | 1308.411 | 1.009

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116015
Approved by: https://github.com/jgong5, https://github.com/ezyang
2024-01-20 09:30:23 +00:00
CaoE
c9528a11dd Add Half support for masked_softmax on CPU (#117028)
Add Half support for `masked_softmax` on CPU.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117028
Approved by: https://github.com/jgong5, https://github.com/cpuhrsch
2024-01-18 08:59:20 +00:00
vfdev-5
1a57c18760 Fixed cuda grads for interpolate::trilinear on non-contig grad output (#117373)
Fixes #113642

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117373
Approved by: https://github.com/lezcano
2024-01-15 18:05:47 +00:00
vmoens
6f0f4f12ca [BugFix] Prevent LSTM to run with wrong input shape (#115542)
Fixes #114874
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115542
Approved by: https://github.com/mikaylagawarecki
2024-01-11 02:57:09 +00:00
Xinya Zhang
e3ca7346ce Re-add initial Flash Attention support on ROCM (#115981)
Note about the Updates:

This PR:
1. skips more flash attention related UTs on MI200
2. Fix additional ATen compiling errors after hipification
3. Fix the author "root" of a specific commit
4. Includes the patch from Nikita in favor of block level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force its running on MI200 nodes. That specific commit must be reverted before merge.

Original PR (https://github.com/pytorch/pytorch/pull/114309) Note:

This pull requests add initial Flash Attention support for AMD/ROCM platform. It added a specialized Triton repository/branch as a compile-time dependency for Flash Attention math library on AMD/ROCM. This triton submodule is not used at runtime and will not be shipped to the final pytorch package. We have the plan to release this specialized Triton as a separate project.

Know limitations:

- Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`.
- Only supports power of two sequence lengths.
- No support for varlen APIs.
- Only support head dimension 16,32,64,128.
- Performance is still being optimized.

Fixes #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115981
Approved by: https://github.com/malfet
2024-01-04 22:21:31 +00:00
Aaron Gokaslan
3fe437b24b [BE]: Update flake8 to v6.1.0 and fix lints (#116591)
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parenthesis i.e.
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):`->`if x > y or y < z:`
  - And `return('...')` -> `return '...'`

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
2024-01-03 06:04:44 +00:00
Sun, Jiayi
c173a9d9b3 add Half support for layer_norm on CPU (#99590)
### Testing
Single socket (icx, 32cores):
| shape | fp32 forward (ms) | fp16 forward (ms) | mixed fp32 fp16 forward (ms) | fp32 backward (ms) | fp16 backward (ms) | mixed fp32 fp16 backward (ms) |
| -- | -- | -- | -- | -- | -- | -- |
| (1, 8, 16) | 0.012 | 0.011 | 0.011 | 0.051 | 0.051 | 0.050 |
| (8 ,8, 16) | 0.013 | 0.013 | 0.013 | 0.054 | 0.053 | 0.051 |
| (32, 8, 16) | 0.015 | 0.014 | 0.014 | 0.059 | 0.054 | 0.052 |
| (64, 128, 56, 56) | 1.875 | 0.790 | 1.016 | 12.845 | 7.151 | 6.985 |
| (64, 128, 256, 256) | 50.226 | 25.462 | 35.736 | 328.957 | 179.615 | 175.618 |

Single core (icx):

| shape | fp32 forward (ms) | fp16 forward (ms) | mixed fp32 fp16 forward (ms) | fp32 backward (ms) | fp16 backward (ms) | mixed fp32 fp16 backward (ms) |
| -- | -- | -- | -- | -- | -- | -- |
| (1, 8, 16) | 0.012 | 0.011 | 0.011 | 0.040 | 0.041 | 0.041 |
| (8 ,8, 16) | 0.012 | 0.012 | 0.012 | 0.042 | 0.042 | 0.042 |
| (32, 8, 16) | 0.027 | 0.014 | 0.014 | 0.048 | 0.048 | 0.046 |
| (64, 128, 56, 56) | 58.054 | 11.034 | 17.928 | 108.603 | 48.816 | 50.244 |
| (64, 128, 256, 256) | 1327.758 | 352.394 | 496.994 | 2846.182 | 1224.247 | 1218.422 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99590
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/cpuhrsch
2023-12-20 01:11:15 +00:00
eqy
d55365dc05 [CUDA] Workaround shmem limit for certain input sizes in AdaptiveAvgPool1D (#115231)
Reference issue #68248

CC @ptrblck @malfet @xwang233

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115231
Approved by: https://github.com/mikaylagawarecki
2023-12-19 22:40:10 +00:00
PyTorch MergeBot
c006c8b50e Revert "markDynamoStrictTest some more (#115885)"
This reverts commit 55ce4693ff.

Reverted https://github.com/pytorch/pytorch/pull/115885 on behalf of https://github.com/atalman due to OSSCI oncall, broke inductor ([comment](https://github.com/pytorch/pytorch/pull/115885#issuecomment-1858409669))
2023-12-15 19:51:24 +00:00
rzou
55ce4693ff markDynamoStrictTest some more (#115885)
Featuring
test_native_mha.py
test_nn.py
test_prims.py
test_schema_check.py
test_serialization.py
test_show_pickle.py
test_sort_and_select.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115885
Approved by: https://github.com/voznesenskym
ghstack dependencies: #115845, #115855, #115856, #115857, #115858, #115870, #115871, #115879
2023-12-15 13:19:52 +00:00
eqy
9056903b09 [CUDA] 64-bit indexing for avg_pool_backward (#114193)
Fixes #113833

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114193
Approved by: https://github.com/malfet
2023-12-15 03:58:46 +00:00