### Before this PR:
`torch.utils.swap_tensors(a, b)` required the `use_count` of `a` and `b` to be 1
```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here would fail due to the reference held by the AccumulateGrad node, which is not cleaned up after backward
# torch.utils.swap_tensors(a, b)
del out
# Calling swap_tensors here would pass
torch.utils.swap_tensors(a, b)
```
### After this PR:
`torch.utils.swap_tensors(a, b)` requires the `use_count` of `a` and `b` to be 1, or 2 if the second reference is held by an `AccumulateGrad` node.
In the latter case, a pre-hook is registered on the `AccumulateGrad` node so that it fails if it is ever called (i.e. if the user attempts to backward through the graph).
```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here is ok
torch.utils.swap_tensors(a, b)
# If we ever backward through the graph and reach the AccumulateGrad node, it errors, reporting that the node was poisoned by swap_tensors
```
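For illustration, a minimal sketch of what hitting the poisoned node looks like (the error type and wording are assumptions; `a + 0` is used so no tensors are saved for backward and only the `AccumulateGrad` node holds the extra reference):
```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a + 0                      # AddBackward0 saves no tensors
out.sum().backward()
torch.utils.swap_tensors(a, b)   # ok: the second reference is the AccumulateGrad node
try:
    out.sum().backward()         # traverses the graph and reaches the poisoned AccumulateGrad node
except RuntimeError as e:
    print(e)                     # reports that the node was poisoned by swap_tensors (wording approximate)
```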
### Application to `nn.Module`
This issue is especially pertinent in the context of `nn.Module`, where parameters will have `AccumulateGrad` nodes initialized after the forward pass. Specifically, this is intended to address https://github.com/pytorch/pytorch/pull/126814#issuecomment-2127777866. Previously, the snippet below would fail at `m.cpu()`; we want it to succeed and only raise an error if the user later attempts to backward through the poisoned `AccumulateGrad` node.
```python
import torch
import torch.nn as nn
m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)
out.sum().backward()
m.cpu()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127313
Approved by: https://github.com/soulitzer
Fixes #121188
Prevent Segmentation Fault in `torch._C._nn.thnn_conv2d`
Previously, calling `torch._C._nn.thnn_conv2d` with invalid arguments for `padding`, `stride`, or `kernel_size` would result in a segmentation fault. This issue has been resolved by adding argument validation (using `TORCH_CHECK`). Now, when invalid arguments are detected, a runtime error is raised with a message describing the expected format.
Additionally, this commit includes tests to cover the three referenced cases.
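As a rough illustration of the new behavior (the specific invalid argument and the positional argument order follow the ATen schema and are assumptions, not taken from the PR's tests):
```python
import torch

inp = torch.randn(1, 1, 5, 5)
weight = torch.randn(1, 1, 3, 3)
try:
    # A zero stride is one example of an invalid argument that previously
    # crashed instead of raising; it should now raise a RuntimeError
    torch._C._nn.thnn_conv2d(inp, weight, [3, 3], None, [0, 0], [0, 0])
except RuntimeError as e:
    print(e)
```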
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121906
Approved by: https://github.com/janeyx99
Automatic fixes that replace certain list comprehensions with generator expressions where appropriate, so that they are immediately consumed. This is preview functionality in ruff for rule C419 and it was applied automatically.
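For illustration, the kind of rewrite C419 performs (a representative example, not taken from this PR's diff):
```python
values = [3, -1, 4]

# Before: the list comprehension is materialized only to be consumed by any()
has_negative_list = any([x < 0 for x in values])

# After: the generator expression is consumed directly, avoiding the intermediate list
has_negative_gen = any(x < 0 for x in values)

assert has_negative_list == has_negative_gen
```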
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
Summary: To unblock training where `upsample_nearest2d` involves input or output tensors larger than 2^31 elements. This comes up frequently in image & video applications.
Test Plan:
```
buck2 test mode/opt //caffe2/test:test_nn_cuda -- test_upsamplingnearest2d_backward_64bit_indexing
```
Benchmarking (N5207417):
```
device_ms, cpu_ms, gb/device_ms*1000
# before changes
118.03993721008301 124.09385920000001 98.72685525972494
# after changes
118.05780944824218 124.10893509999994 98.71190944734577
```
Differential Revision: D55625666
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123682
Approved by: https://github.com/ezyang
This PR proposes to keep the same key order as the original state_dict, as the issue creator proposed. It also fixes a bug in how ``_metadata`` is handled (see below), and makes other small changes to properly remove the prefix when it is present.
In the original code, ``_metadata`` was handled as a ``key``.
```
# also strip the prefix in metadata if any.
if "_metadata" in state_dict:
```
This is not the case: ``_metadata`` is actually an ``attribute``. Hence, the previous condition is changed to:
```
# also strip the prefix in metadata if any.
if hasattr(state_dict, "_metadata"):
```
This PR also includes the necessary test.
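A minimal usage sketch of the helper this PR touches, assuming a `"module."`-prefixed state_dict such as one saved from a DDP-wrapped model:
```python
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

m = nn.Linear(3, 5)
wrapped_sd = {"module." + k: v for k, v in m.state_dict().items()}

# Strips the "module." prefix in place, preserving key order and any _metadata attribute
consume_prefix_in_state_dict_if_present(wrapped_sd, "module.")
m.load_state_dict(wrapped_sd)
```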
Fixes #106942
Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117464
Approved by: https://github.com/mikaylagawarecki
Fixes #121093
Previously, calling the following functions with invalid padding dimensions would cause a segmentation fault:
```
torch._C._nn.replication_pad1d, torch._C._nn.replication_pad2d, torch._C._nn.replication_pad3d
```
To fix this, condition checks were added so that a runtime error with a descriptive message is raised instead, specifying the required dimensions.
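One hypothetical repro of the new behavior (the exact invalid-padding cases covered by the PR's tests may differ):
```python
import torch

x = torch.randn(1, 2, 3)
try:
    # Negative padding that shrinks the output below one element previously
    # crashed; it should now raise a RuntimeError with a descriptive message
    torch._C._nn.replication_pad1d(x, [-3, -3])
except RuntimeError as e:
    print(e)
```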
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121298
Approved by: https://github.com/mikaylagawarecki
Added `torch.__future__.{get/set}_swap_module_params_on_conversion`, which defaults to `False` for now, but we probably want to override this and default to `True` in `nn.Module._apply` if the input is a tensor subclass.
From offline discussion, for now we are **not** allowing `swap_tensors` after the first module forward has been run*** if the autograd graph is still alive. The reason is that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1. The first forward pass will install `AccumulateGrad` nodes on each param, which [bump the refcount of the associated TensorImpl](6cf1fc66e3/torch/csrc/autograd/variable.cpp (L307)). **Future work might be to swap the refs that the `AccumulateGrad` nodes hold if it is necessary.**
***From this, it might seem like we don't need to handle gradients. However, I still handle the grads for the edge case that the grads are set via `p.grad = grad` OR the autograd graph is no longer alive because the output has been garbage collected.
If any `swap_tensors` fails on any of the parameters in the `nn.Module` we raise an error.
**`RNNBase` overrides `nn.Module._apply()` and installs weakrefs on some parameters. As a result, all modules that inherit from `RNNBase` (`RNN`, `GRU`, and `LSTM`) cannot use the `swap_tensors` path as of now.**
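A minimal usage sketch of the new flag (using a module that does not override `_apply` the way `RNNBase` does):
```python
import torch
import torch.nn as nn

# Opt in to swapping module parameters in place during conversions such as .to()/.cpu()/.half()
torch.__future__.set_swap_module_params_on_conversion(True)
print(torch.__future__.get_swap_module_params_on_conversion())  # True

m = nn.Linear(3, 5)
m.half()  # parameters are converted via torch.utils.swap_tensors rather than .data assignment
```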
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117167
Approved by: https://github.com/albanD
ghstack dependencies: #118028
Description:
- Fixed error in bicubic upsampling aa=false for uint8 input. This is seen in the test suite:
```diff
- self.assertLess(diff.max(), 15)
+ self.assertLess(diff.max(), 5)
```
Reducing the input range does not fully remove the clipping effect, which is why the threshold is 5 and not around 1.
- Renamed methods
- The error is mostly visible for upsampling (smaller -> larger) mode on the boundary values
More details on the bug:
For uint8 input and antialiasing=False we use a separable algorithm (using temp buffers and interpolating the dimensions one by one) where interpolation weights and input indices are computed and stored using index ranges: `index_min` and `index_size`; weights outside of `index_size` are zeros. For example, for an output point we can have index_min=10 and index_size=4 with 4 non-zero weights, so the output value is computed as
```
out_value = sum(src[i + index_min] * w for i, w in zip(range(4), weights))
```
When computing index ranges and weights for output points near the boundaries, we clamp `index_min` between 0 and input_size, and `index_size` becomes smaller than 4. This approach is OK for antialiasing=True but is not correct for antialiasing=False, where the weights end up computed incorrectly:
```
-- output index i= 0
regular float32 approach:
source indices: [-2, -1, 0, 1] -> outbounded values are clamped to boundaries -> [0, 0, 0, 1]
interp weights: [-0.07200000000000006, 0.4600000000000001, 0.72, -0.1080000000000001]
separable uint8 approach:
source indices coming from index ranges (min, size): [0, 1]
incorrect interp weights computed with current implementation : [1.1764705882352944, -0.17647058823529432, 0.0, 0.0]
fixed interp weights in the PR: [1.108, -0.1080000000000001, 0.0, 0.0]
Note: weight value corresponding to source index 0 is 1.108 = -0.07200000000000006 + 0.4600000000000001 + 0.72 and weight value corresponding to source index 1 is -0.1080000000000001 is the same as in f32 approach.
```
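To make the fix concrete, a small Python sketch (a hypothetical helper, not the actual C++ implementation) that folds the weights of out-of-bound source indices onto the clamped boundary index, reproducing the numbers above:
```python
def fold_boundary_weights(src_indices, weights, input_size):
    # Accumulate each weight onto its clamped (in-bounds) source index
    folded = {}
    for idx, w in zip(src_indices, weights):
        clamped = min(max(idx, 0), input_size - 1)
        folded[clamped] = folded.get(clamped, 0.0) + w
    return folded

# Output index 0 from the example above: indices [-2, -1, 0, 1] clamp to [0, 0, 0, 1]
print(fold_boundary_weights([-2, -1, 0, 1],
                            [-0.072, 0.46, 0.72, -0.108],
                            input_size=4))
# -> {0: ~1.108, 1: -0.108}, matching the fixed weights reported in the PR
```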
Quick benchmark to ensure perfs no regression:
```
[------------------------------------------------------------------------------------ Resize ------------------------------------------------------------------------------------]
| torch (2.3.0a0+gitfda85a6) PR | torch (2.3.0a0+git0d1e705) Nightly | Speed-up: PR vs Nightly
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------------------------------
3 torch.uint8 channels_first bilinear (400, 400) -> (224, 224) aa=False | 440.996 (+-2.044) | 470.824 (+-5.927) | 1.068 (+-0.000)
3 torch.uint8 channels_first bicubic (400, 400) -> (224, 224) aa=False | 463.565 (+-1.519) | 497.231 (+-10.825) | 1.073 (+-0.000)
3 torch.uint8 channels_first bilinear (400, 400) -> (700, 700) aa=False | 1717.000 (+-28.589) | 1915.570 (+-43.397) | 1.116 (+-0.000)
3 torch.uint8 channels_first bicubic (400, 400) -> (700, 700) aa=False | 1801.954 (+-22.391) | 1981.501 (+-37.034) | 1.100 (+-0.000)
3 torch.uint8 channels_last bilinear (400, 400) -> (224, 224) aa=False | 199.599 (+-0.851) | 196.535 (+-3.788) | 0.985 (+-0.000)
3 torch.uint8 channels_last bicubic (400, 400) -> (224, 224) aa=False | 243.126 (+-0.681) | 240.695 (+-2.306) | 0.990 (+-0.000)
3 torch.uint8 channels_last bilinear (400, 400) -> (700, 700) aa=False | 686.270 (+-2.870) | 687.769 (+-17.863) | 1.002 (+-0.000)
3 torch.uint8 channels_last bicubic (400, 400) -> (700, 700) aa=False | 899.509 (+-5.377) | 899.063 (+-9.001) | 1.000 (+-0.000)
Times are in microseconds (us).
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118389
Approved by: https://github.com/NicolasHug
ghstack dependencies: #118388
Description:
- Lowered error thresholds and added an input range for bicubic to expose the inconsistency in the upsampling (smaller -> larger) bicubic aa=False implementation for the uint8 input dtype
- Updated out-dated comments
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118388
Approved by: https://github.com/NicolasHug
These operators are not used and have been deprecated since #72690 (Feb 2022).
BC-breaking message:
`TorchScript` models that were exported with the deprecated
`torch.jit.quantized` API will no longer be loadable, as the required
internal operators have been removed.
Please re-export your models using the newer `torch.ao.quantization` API
instead.
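For reference, one possible migration path using dynamic quantization (a sketch of the newer API, not the only route; the module set and dtype are illustrative):
```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

# Re-export with TorchScript using the torch.ao.quantization-produced modules
scripted = torch.jit.script(qmodel)
torch.jit.save(scripted, "model_quantized.pt")
```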
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112153
Approved by: https://github.com/jerryzh168
Note about the Updates:
This PR:
1. Skips more Flash Attention related unit tests on MI200
2. Fixes additional ATen compilation errors after hipification
3. Fixes the author "root" of a specific commit
4. Includes the patch from Nikita in favor of block-level static initialization
CAVEAT: This revised PR has a commit that modifies the CI to force its running on MI200 nodes. That specific commit must be reverted before merge.
Original PR (https://github.com/pytorch/pytorch/pull/114309) Note:
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final pytorch package. We plan to release this specialized Triton as a separate project.
Known limitations:
- Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`); a quick way to check this is shown after this list.
- Only supports power-of-two sequence lengths.
- No support for varlen APIs.
- Only supports head dimensions 16, 32, 64, and 128.
- Performance is still being optimized.
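A quick check for the first limitation from Python (assuming the ROCm build exposes `gcnArchName` on the device properties, as recent builds do):
```python
import torch

if torch.cuda.is_available():
    arch = torch.cuda.get_device_properties(0).gcnArchName
    print(arch, "-> MI200 series" if arch.startswith("gfx90a") else "-> unsupported for this path")
```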
Fixes #112997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115981
Approved by: https://github.com/malfet
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parentheses, i.e.
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):` -> `if x > y or y < z:`
  - `return('...')` -> `return '...'`
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet