### Before this PR:
`torch.utils.swap_tensors(a, b)` required the `use_count` of `a` and `b` to be 1
```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here would fail due to the reference held by AccumulateGrad node, which is not cleaned up after backward
# torch.utils.swap_tensors(a, b)
del out
# Calling swap_tensors here would pass
torch.utils.swap_tensors(a, b)
```
### After this PR:
`torch.utils.swap_tensors(a, b)` requires the `use_count` of `a` and `b` to be 1 or 2 IF the second reference is held by `AccumulateGrad`
A pre-hook will be registered on the `AccumulateGrad` node so that it will fail if it is called (i.e. if user attempts to backward through the graph).
```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here is ok
torch.utils.swap_tensors(a, b)
# If we ever backward to the AccumulateGrad node it will error that it was poisoned by swap_tensors
```
### Application to `nn.Module`
This issue is especially pertinent in context of `nn.Module` where parameters will have `AccumulateGrad` nodes initialized after forward. Specifically, this is intended to address https://github.com/pytorch/pytorch/pull/126814#issuecomment-2127777866. Previously, this would fail at the `m.cpu()` but we want users to be able to do something like the following, and instead raise an error if the user ever attempts to backward through the poisoned `AccumulateGrad` node
```python
import torch
import torch.nn as nn
m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)
out.sum().backward()
m.cpu()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127313
Approved by: https://github.com/soulitzer
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves#126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
# Motivation
## for `torch.amp.GradScaler`,
- `torch.cpu.amp.GradScaler(args...)` is completely equivalent to `torch. amp.GradScaler("cpu", args...)`.
- `torch.cuda.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cuda", args...)`.
So, we intend to depreate them and **strongly recommend** developer to use `torch.amp.GradScaler`.
## for `custom_fwd` and `custom_bwd`,
this is a good solution to make the custom function run with or without effect even in an autocast-enabled region and can be shared by other backends, like CPU and XPU.
So we generalize it to be device-agnostic and put them int `torch/amp/autocast_mode.py` and re-expose to `torch.amp.custom_fwd` and `torch.amp.custom_bwd`. Meanwhile, we deprecate `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd`.
# Additional Context
Add UT to cover the deprecated warning.
No need for more UTs to cover the functionality of `torch.amp.custom_f/bwd`, the existing UTs that previously covered the functionality of `torch.cuda.amp.custom_f/bwd` can cover them.
To facilitate the review, we separate these code changes to two PRs. The first PR cover `torch.amp.GradScaler`. The follow-up covers `custom_fwd` and `custom_bwd`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126527
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/janeyx99, https://github.com/EikanWang
Fixes#71398
Add `__reduce__` and `__setstate__` methods for `torch._C.Generator`.
`__reduce__` returns a tuple of 3 values:
1. `torch.Generator` itself.
2. A one-element tuple containing the `torch.device` to create the `Generator` with, since this cannot be changed after the object is created.
3. The state, a three-element tuple: the initial seed, the offset (or `None` if a CPU `Generator`), and the RNG state tensor.
`__setstate__` calls `manual_seed`, `set_offset` (if not `None`), and `set_state` on each respective element of the state.
Added test demonstrating successful reserialization with cpu and cuda `Generator`s.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126271
Approved by: https://github.com/ezyang
Fixes#123451 (only addresses test_torch.py cases)
This PR solves the specific task to update `test_grad_scaling_autocast` and `test_params_invalidated_with_grads_invalidated_between_unscale_and_step` in `test/test_torch.py` to use the new OptimizerInfo infrastructure.
I have combined tests that call `_grad_scaling_autocast_test` into one called `test_grad_scaling_autocast` and used `_get_optim_inputs_including_global_cliquey_kwargs` to avoid hard-coded configurations.
```
$ lintrunner test/test_cuda.py
ok No lint issues.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125538
Approved by: https://github.com/janeyx99
As per title.
This ensures that all the places where we assume the method defined in _tensor.py do exist.
BC-Breaking: This is bc-breaking as the user cannot subclass this private class anymore.
You should replace any use of _TensorBase to Tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125558
Approved by: https://github.com/ezyang
Fixes#121965
This PR hopes to add support complex numbers in the scatter/gather related kernels. For brevity, I will only include `complex<float>` for now as `complex<double>`, for example, will be more complicated.
C++ unit tests are currently passing alongside tests in `test_scatter_gather_ops.py`. Python test suites also seem to be passing.
Please keep the following in mind:
1) I think this is my first time using Pytorch.
2) This is my first contribution to Pytorch.
Environment:
3080 & WSL 2. `nvcc` is at 12.4.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124809
Approved by: https://github.com/mikaylagawarecki
Fixes#121965
This PR hopes to add support complex numbers in the scatter/gather related kernels. For brevity, I will only include `complex<float>` for now as `complex<double>`, for example, will be more complicated.
C++ unit tests are currently passing alongside tests in `test_scatter_gather_ops.py`. Python test suites also seem to be passing.
Please keep the following in mind:
1) I think this is my first time using Pytorch.
2) This is my first contribution to Pytorch.
Environment:
3080 & WSL 2. `nvcc` is at 12.4.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124809
Approved by: https://github.com/eqy, https://github.com/mikaylagawarecki
Update ruff to 0.4.1 .
This version fixes a lot false negatives/false positives, is 20-40% faster, and has various other bug fixes.
Below is a before and after table showing the execution time of ruff lint and ruff format in milliseconds courtesy of https://astral.sh/blog/ruff-v0.4.0
| Repository | Linter (v0.3) | Linter (v0.4) | Formatter (v0.3) | Formatter (v0.4) |
|----------------------------------------------------|---------------|---------------|------------------|------------------|
| [pytorch/pytorch](https://github.com/pytorch/pytorch) | 328.7 | 251.8 | 351.1 | 274.9 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124549
Approved by: https://github.com/ezyang
On par with `CUDA` implementation.
For `autocast` logic, same with `CUDA` + `Fused Adam`:
- check inf in `gradscalar.step`
- In fused kernel, if there is `inf`, do nothing. If not, unscale the grad ( also write back) and update the param.
**TestPlan**:
```
# extend CUDA only test for CPU fused adagrad
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused
# extend fused test
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict
# newly added test (follow 6b1f13ea2f/test/test_cuda.py (L1108))
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
```
**Benchmark**:
**5.1x** on 56 core SPR
**Parameter-size=1M**
**Nparams=10**
[test script](https://gist.github.com/zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7)
```
numactl -C 0-55 -m 0 python bench_adam.py
non-fused 6.0174267292022705 s
fused 1.1787631511688232 s
```
**Note: Fused kernel accuracy**
The accuracy failure in CI shows a little higher than default tolerance
```
2024-04-02T06:09:16.2213887Z Mismatched elements: 21 / 64 (32.8%)
2024-04-02T06:09:16.2214339Z Greatest absolute difference: 1.5735626220703125e-05 at index (6, 6) (up to 1e-05 allowed)
2024-04-02T06:09:16.2214813Z Greatest relative difference: 1.0073336852656212e-05 at index (4, 1) (up to 1.3e-06 allowed)
```
I have debug it step by step and unfortunately we may not able to make the `fused kernel` exactly same with `non fused` one due to compiler optimizations.
For example, in non-fused impl
```
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
and in fused impl
```
exp_avg_sq_ptr[d] = scalar_t(beta2) * exp_avg_sq_ptr[d];
// std::cout << "exp_avg_sq " << exp_avg_sq_ptr[d] << std::endl;
exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
```
If I keep `std::cout`, I can get exactly same results in UT
```
===============param
0.6796758770942688
0.6796758770942688
```
But when I comment out it, there will be a difference
```
===============param
0.6796758770942688
0.6796759366989136
```
So I will make the tolerance a little higher than default one.
Co-authored-by: Jane Xu <janeyx@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123074
Approved by: https://github.com/jgong5, https://github.com/janeyx99
This PR intends to fix the following issue when swapping two tensors
```python
>>> import torch
>>> torch.manual_seed(5)
>>> t1 = torch.randn(2)
>>> t2 = torch.randn(3)
>>> t1
tensor([-0.4868, -0.6038])
>>> t2
tensor([-0.5581, 0.6675, -0.1974])
>>> torch.utils.swap_tensors(t1, t2)
>>> t1
tensor([-0.5581, 0.6675, -0.1974])
>>> t2
tensor([-0.4868, -0.6038])
>>> t1.fill_(0.5) # t1 back to its unswapped state :o
tensor([-0.4868, -0.6038])
```
What happens here is that in `THPVariable_Wrap` (which is used when going back from C++ --> Python), we check if the TensorImpl of the tensor to be returned already has a pointer to a PyObject in its PyObject slot. If this is the case then this object is returned.
57491d2046/torch/csrc/autograd/python_variable.cpp (L271-L292)
When we run any operation that returns the same TensorImpl (e.g. inplace op, `t.to(dtype=t.dtype)`, etc.), although `t1` now has `t2`'s TensorImpl, `t2`'s TensorImpl still has a reference to `t2`, so when we do the op on `t1` and `THPVariable_Wrap` attempts to return the pointer to the TensorImpl's PyObject, we return a pointer to `t2` instead.
The TensorImpl should have the PyObjects in their PyObjectSlots swapped as well in `swap_tensors`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116955
Approved by: https://github.com/albanD
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parenthesis i.e.
- `assert(a == b)` -> `assert a == b`
- `if(x > y or y < z):`->`if x > y or y < z:`
- And `return('...')` -> `return '...'`
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
After this refactor:
* ```TorchVariable``` definition and all references are removed.
* All ```is_allowed``` references except one are removed.
- The only left one is in ```torch/_dynamo/decorators:_disallow_in_graph_helper```. It was called when users put ```disallow_in_graph``` decorator on a function. Since we use the lists in ```trace_rules``` to decide the function's trace rule, so the decorator would only be used as customer function rather than torch functions. I'll defer this to a separate decorator refactor PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116312
Approved by: https://github.com/jansel
After this refactor:
* ```TorchVariable``` definition and all references are removed.
* All ```is_allowed``` references except one are removed.
- The only left one is in ```torch/_dynamo/decorators:_disallow_in_graph_helper```. It was called when users put ```disallow_in_graph``` decorator on a function. Since we use the lists in ```trace_rules``` to decide the function's trace rule, so the decorator would only be used as customer function rather than torch functions. I'll defer this to a separate decorator refactor PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116312
Approved by: https://github.com/jansel