## Summary
Adds missing type annotations to `torch.nn.init` and removes `# mypy: allow-untyped-defs` since all functions are now properly typed.
## Changes
- Added missing type annotations to initialization functions in the module.
- Added missing typing imports: `Any`, `Callable`, `Union`
- Removed `# mypy: allow-untyped-defs` comment
- Create Literal types for kaiming initialization mode and nonlinearity.
- Created `__all__`
## Why
Better IDE support, catches type errors earlier, and brings the module up to PyTorch's typing standards. No runtime changes - purely additive typing improvements.
Tested with existing test suite and lintrunner.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154504
Approved by: https://github.com/Skylion007
I found the same issue as #147490 (@jibril-b-coulibaly).
There's an equivalent in the [doc-string](https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html#rnn) of `torch.nn.RNN`:
```python
# Efficient implementation equivalent to the following with bidirectional=False
def forward(x, hx=None):
if batch_first:
x = x.transpose(0, 1)
seq_len, batch_size, _ = x.size()
if hx is None:
hx = torch.zeros(num_layers, batch_size, hidden_size)
h_t_minus_1 = hx
h_t = hx
output = []
for t in range(seq_len):
for layer in range(num_layers):
h_t[layer] = torch.tanh(
x[t] @ weight_ih[layer].T
+ bias_ih[layer]
+ h_t_minus_1[layer] @ weight_hh[layer].T
+ bias_hh[layer]
)
output.append(h_t[-1])
h_t_minus_1 = h_t
output = torch.stack(output)
if batch_first:
output = output.transpose(0, 1)
return output, h_t
```
However there's something wrong.
1. Like mentioned in #147490, line 499 is wrong
fb55bac3de/torch/nn/modules/rnn.py (L499)
The **input for RNNCell should be different** for different layers.
2. The code contains several hidden **reference-related issues** that may result in unintended modifications to tensors. For example in line 504, this causes all elements in the final output list to point to the same tensor.
fb55bac3de/torch/nn/modules/rnn.py (L504)
3. Some variable is not **defined**. Despite being a relatively minor issue in annotation, it can lead to significant confusion for those who are new to the concept. For example `weight_ih` in line 499
fb55bac3de/torch/nn/modules/rnn.py (L499)
So, i write a runnable version to make it more clear:
```python
# Efficient implementation equivalent to the following with bidirectional=False
rnn = nn.RNN(input_size, hidden_size, num_layers)
params = dict(rnn.named_parameters())
def forward(x, hx=None, batch_first=False):
if batch_first:
x = x.transpose(0, 1)
seq_len, batch_size, _ = x.size()
if hx is None:
hx = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size)
h_t_minus_1 = hx.clone()
h_t = hx.clone()
output = []
for t in range(seq_len):
for layer in range(rnn.num_layers):
input_t = x[t] if layer == 0 else h_t[layer - 1]
h_t[layer] = torch.tanh(
input_t @ params[f"weight_ih_l{layer}"].T
+ h_t_minus_1[layer] @ params[f"weight_hh_l{layer}"].T
+ params[f"bias_hh_l{layer}"]
+ params[f"bias_ih_l{layer}"]
)
output.append(h_t[-1].clone())
h_t_minus_1 = h_t.clone()
output = torch.stack(output)
if batch_first:
output = output.transpose(0, 1)
return output, h_t
```
This code can reproduce the computation of torch.nn.RNN.
For example:
```python
import torch
import torch.nn as nn
torch.manual_seed(0)
input_size, hidden_size, num_layers = 3, 5, 2
rnn = nn.RNN(input_size, hidden_size, num_layers)
params = dict(rnn.named_parameters())
x = torch.randn(10, 4, 3)
official_imp = rnn(x)
my_imp = forward(x)
assert torch.allclose(official_imp[0], my_imp[0])
assert torch.allclose(official_imp[1], my_imp[1])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153620
Approved by: https://github.com/mikaylagawarecki
Fixes#147336
## Context
NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM.
Bringing this to the attention of triton developer @davidberard98 he identified the memory layout of the tensor in HBM to be causing non-pipelined loads into SRAM, causing the slowdown.
To summarize:
In flex attention when performing the FP8 GEMM `softmax_scores @ V` the right operand V must be in column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined if the V tensor isn't column-major in HBM already, leading to substantial performance degradation.
This is because triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see [here](81f93f2c8e/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp (L403))).
i.e., when loading 4 bytes of contiguous data from a tensor stored in row-major in HBM, we have to perform 4 separate non-contiguous writes to SRAM to place those bytes in their new location in the col-major layout in SRAM. Thus the load is not a candidate for pipelining w/ cp.async and just moves data to registers then performs a series of single byte stores.
## Fix summary
- To fix this, we should enforce memory layouts for Q, K, V in FlexAttention when fp8 is being used, to ensure they each exist in HBM in the necessary memory layout to facilitate pipelined loads into SRAM ahead of the FP8 GEMMs
## Benchmarks
Rerunning the repro we see fp8 runtime is reduced from 120% of bf16 to 76% of bf16 runtime.
Before fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us
```
After fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153357
Approved by: https://github.com/ngimel, https://github.com/davidberard98
C++ Reducer is silently incorrect under CA, its implementation is no-oping the collective. I'm guessing that it was no-op'd because in DDP + python reducer, the C++ reducer is still being initialized.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152735
Approved by: https://github.com/fegin
ghstack dependencies: #153300, #152689
Fixes#147336
## Context
NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM.
Bringing this to the attention of triton developer @davidberard98 he identified the memory layout of the tensor in HBM to be causing non-pipelined loads into SRAM, causing the slowdown.
To summarize:
In flex attention when performing the FP8 GEMM `softmax_scores @ V` the right operand V must be in column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined if the V tensor isn't column-major in HBM already, leading to substantial performance degradation.
This is because triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see [here](81f93f2c8e/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp (L403))).
i.e., when loading 4 bytes of contiguous data from a tensor stored in row-major in HBM, we have to perform 4 separate non-contiguous writes to SRAM to place those bytes in their new location in the col-major layout in SRAM. Thus the load is not a candidate for pipelining w/ cp.async and just moves data to registers then performs a series of single byte stores.
## Fix summary
- To fix this, we should enforce memory layouts for Q, K, V in FlexAttention when fp8 is being used, to ensure they each exist in HBM in the necessary memory layout to facilitate pipelined loads into SRAM ahead of the FP8 GEMMs
## Benchmarks
Rerunning the repro we see fp8 runtime is reduced from 120% of bf16 to 76% of bf16 runtime.
Before fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us
```
After fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153357
Approved by: https://github.com/ngimel, https://github.com/davidberard98
When we do `torch.compile(module)`, we eventually end up returning a new
`OptimizedModule` instance, whose `forward` method is the result of
`torch.compile(mod.__call__)`, meaning it already captures all the extra
logic (e.g., hook firing) for the compiled module.
`OptimizedModule` also inherits `nn.module.__call__`, and thus
has its own hook logic. This is useful for torchao, which injects module
forward hooks to run in eager for quantization purposes.
However, this might create unexpected behavior for global module hooks,
because `torch.compile(module)` causes the hook to fire one extra time
for `OptimizedModule`, when compared to eager.
To preserve BC, we simply emit a warning for this behavior, and let
users decide what to do. This is reasonable because the global module
hooks are documented to be used for debugging/profiling purposes only.
Fixes#149502
Differential Revision: [D74611716](https://our.internmc.facebook.com/intern/diff/D74611716)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152740
Approved by: https://github.com/anijain2305, https://github.com/zou3519
Summary: add one option to allow skipping all reduce unused parameters, this could help improve training throughput significantly when the number of unused parameters is large in the model.
Test Plan: unit tests, CI
Differential Revision: D72282069
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151503
Approved by: https://github.com/mrshenli