## PR
There are a few cases that my previous PR (#153220) didn't cover.
1. The LHS/RHS matters. Today, if you do `torch._check(lhs == rhs)` then it will show up as a deferred runtime assert with `Eq(lhs, rhs)`.
2. There can be transitive replacements. For example, expr1 -> expr2 -> u0. `test_size_with_unbacked_add_expr_transitive` tests for this.
3. An unbacked symint expr may not have a replacement that's purely a symbol, for instance, it could be another expression. `test_size_with_unbacked_add_and_mul_expr` tests for this.
## Device assertion msg
```
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
...
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
```
## Autotuning code setup
This is the autotuning code for a concat kernel which takes input tensors (`in_buf`) and writes them to the (`out_buf`).
It's important to note the size of `in_buf0` is the same as `in_buf1` don't match along dim=0. This is bad because all concat inputs must share the same size for each dim except for the concat dim (here that's dim=1).
```
in_buf0 = generate_example_value(size=(u1 + s0, 256)) # concrete size is (17900, 256)
in_buf1 = generate_example_value(size=(u0, 10)) # concrete size is (8192, 10)
...
out_buf = generate_example_value(size=(u1 + s0, 266)) # concrete size is (17900, 256+10)
triton_poi_fused_cat_1.run(in_buf0, in_buf1, ..., out_buf, xnumel=(u1 + s0) * 266 ...)
```
If we look into the kernel code, you'll see that `tmp9` loads `in_buf1` (our incorrectly shaped input tensor). There is also a mask to prevent OOB loads.
- `tmp6` makes sure we're only loading with the `xindex` from 256 to 264.
- `xmask` makes sure we're only loading with the `xindex` within `xnumel`.
- `tmp6 & xmask` together is essentially checking `0 ≤ x0 < u1 + s0` and `256 ≤ x1 < 264`.
The mask logic is correct, however, `in_buf1` has the shape `[8192, 10]` this means any load where `8192 ≤ x0 < u1 + s0` will be an OOB load.
```
def triton_poi_fused_cat_1(in_buf0, in_buf1, ... out_buf, xnumel, XBLOCK):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)
xmask = xindex < xnumel
x0 = (xindex % 264)
x1 = xindex // 264
...
tmp6 = x0 >= tl.full([1], value=256)
tmp9 = tl.load(in_buf1 + (x1), tmp6 & xmask)
# device assertion is thrown here
tl.device_assert(((0 <= tl.broadcast_to(tmp13, [XBLOCK])) & (tl.broadcast_to(tmp13, [XBLOCK]) < ks0)) | ~(xmask & tmp6), "index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153768
Approved by: https://github.com/jingsh
https://github.com/pytorch/pytorch/pull/152708 expanded support of `get_estimated_runtime` to many more types of `SchedulerNodes`. This caused an increase in compile time because we're always calling `get_estimated_runtime` to populate the metrics table. This PR adds a flag for this logging, which reduces the instruction count by 8%. Long term, we should probably merge metrics.py with TORCH_LOGS/tlparse (suggestion from @xmfan).
Update: added support for TORCH_LOGS for the metrics logging.
Test Plan:
mm_loop.py and many existing tests cover.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153506
Approved by: https://github.com/eellison
This PR adds code generation for CK-tile based universal gemm kernels to the CK backend for Inductor, and adds these kernels to autotune choices.
Unlike legacy-CK based kernels (which are generated by parsing the CK instances from CK library), we generate the set of instances by manually specifying the tuning parameters.
This PR introduces a new template for code generation, and compilation/autotuning is handled by the existing infrastructure.
Points of discussion:
* For simplicity and reduced coupling with CK, the instance filter checks only data type and layout, and doesn't check the alignment requirement - meaning that more instances will be compiled than necessary - while keeping the code generation independent from internal CK logic which checks the alignment validity at runtime
* CK-tile instances are enabled whenever legacy-CK instances are enabled. A config knob could be introduced to differentiate between the instance types if that's needed
* Whether gemm problem size K is ever dynamic, since whenever it's not a compile-time constant, we need to perform a runtime dispatch between several kernels
** Testing **
Use the existing tests in `test/inductor/test_ck_backend.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152341
Approved by: https://github.com/chenyang78
Work around issues like #153960, #152623
NCCL 2.26 seems to introduce random hang in non-blocking API mode. This PR opts out of non-blocking mode to work around it. Previously torch turned it on by default in eager init (i.e. `device_id` passed) to avoid init overhead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154055
Approved by: https://github.com/atalman
Previously when we lower backward AOT due to symints, the post grad passes would leave the bw_module in a non-runnable state. This caused issues when compiled autograd tried to trace at runtime. So we had inductor operate on a deepcopy of bw_module.
But with https://github.com/pytorch/pytorch/issues/153993, we see that deepcopying real tensors will fail under fake mode due to the device type mismatch between the fake tensors ("meta" device) and the real tensor. So by disabling fake mode, we avoid these errors. This change is a strict improvement over current, but it does reveal that this deepcopy can theoretically cause OOMs.
FIXES https://github.com/pytorch/pytorch/issues/153993
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153999
Approved by: https://github.com/jamesjwu, https://github.com/bdhirsh
I found the same issue as #147490 (@jibril-b-coulibaly).
There's an equivalent in the [doc-string](https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html#rnn) of `torch.nn.RNN`:
```python
# Efficient implementation equivalent to the following with bidirectional=False
def forward(x, hx=None):
if batch_first:
x = x.transpose(0, 1)
seq_len, batch_size, _ = x.size()
if hx is None:
hx = torch.zeros(num_layers, batch_size, hidden_size)
h_t_minus_1 = hx
h_t = hx
output = []
for t in range(seq_len):
for layer in range(num_layers):
h_t[layer] = torch.tanh(
x[t] @ weight_ih[layer].T
+ bias_ih[layer]
+ h_t_minus_1[layer] @ weight_hh[layer].T
+ bias_hh[layer]
)
output.append(h_t[-1])
h_t_minus_1 = h_t
output = torch.stack(output)
if batch_first:
output = output.transpose(0, 1)
return output, h_t
```
However there's something wrong.
1. Like mentioned in #147490, line 499 is wrong
fb55bac3de/torch/nn/modules/rnn.py (L499)
The **input for RNNCell should be different** for different layers.
2. The code contains several hidden **reference-related issues** that may result in unintended modifications to tensors. For example in line 504, this causes all elements in the final output list to point to the same tensor.
fb55bac3de/torch/nn/modules/rnn.py (L504)
3. Some variable is not **defined**. Despite being a relatively minor issue in annotation, it can lead to significant confusion for those who are new to the concept. For example `weight_ih` in line 499
fb55bac3de/torch/nn/modules/rnn.py (L499)
So, i write a runnable version to make it more clear:
```python
# Efficient implementation equivalent to the following with bidirectional=False
rnn = nn.RNN(input_size, hidden_size, num_layers)
params = dict(rnn.named_parameters())
def forward(x, hx=None, batch_first=False):
if batch_first:
x = x.transpose(0, 1)
seq_len, batch_size, _ = x.size()
if hx is None:
hx = torch.zeros(rnn.num_layers, batch_size, rnn.hidden_size)
h_t_minus_1 = hx.clone()
h_t = hx.clone()
output = []
for t in range(seq_len):
for layer in range(rnn.num_layers):
input_t = x[t] if layer == 0 else h_t[layer - 1]
h_t[layer] = torch.tanh(
input_t @ params[f"weight_ih_l{layer}"].T
+ h_t_minus_1[layer] @ params[f"weight_hh_l{layer}"].T
+ params[f"bias_hh_l{layer}"]
+ params[f"bias_ih_l{layer}"]
)
output.append(h_t[-1].clone())
h_t_minus_1 = h_t.clone()
output = torch.stack(output)
if batch_first:
output = output.transpose(0, 1)
return output, h_t
```
This code can reproduce the computation of torch.nn.RNN.
For example:
```python
import torch
import torch.nn as nn
torch.manual_seed(0)
input_size, hidden_size, num_layers = 3, 5, 2
rnn = nn.RNN(input_size, hidden_size, num_layers)
params = dict(rnn.named_parameters())
x = torch.randn(10, 4, 3)
official_imp = rnn(x)
my_imp = forward(x)
assert torch.allclose(official_imp[0], my_imp[0])
assert torch.allclose(official_imp[1], my_imp[1])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153620
Approved by: https://github.com/mikaylagawarecki
Summary: Remove anonymous namespace in model_container.h to fix the following compiler warning,
```
warning: ‘torch::aot_inductor::AOTInductorModelContainer’ has a field ‘torch::aot_inductor::AOTInductorModelContainer::constant_folded_’ whose type uses the anonymous namespace [-Wsubobject-linkage]
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154033
Approved by: https://github.com/chenyang78
Added AOTIModelContainerRunnerMps and a shim for mps fallback ops.
I also added a mps-specific shim which contains one operator, which will be used to set arguments being passed to the Metal kernel:
```
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_set_arg(
AOTIMetalKernelFunctionHandle func,
unsigned idx,
AtenTensorHandle tensor);
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153964
Approved by: https://github.com/malfet, https://github.com/desertfire
In a model, we see ~~ 40% of the time in mm/addmm tuning. The model have 2000 mm,
many of which receives the same input shapes.
with autotune enabled, this become expensive, while we already cache auto tuning results, we
did not used to cache the generation of the python code and the loading for each config that we autotune on.
This diff handles the code generation part (template expansions) a previous diff handled the loading part.
This is expected to save 20% of the model I am working on.
How do we do the caching?
For a given configurations and input layout, the generated code is always the same. One caveat is that
some other information collected during code generation are input dependent (namely depends on inputs
names and symbol names in inputs). and not just layout. !
To handle those we use a record and replay approach, where we record the functions that are called during
code generation that effect those outputs and replay them at a cache hit.
Effect on the current benchmark on a local run on dev server.
mm_loop. 24115830838 -> 18362098019
mm_loop_dynamic 30506097176-> 25697270062
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151773
Approved by: https://github.com/eellison
Fixes#153646
This PR refactors the logging behavior in the FX pass insert_deferred_runtime_asserts and runtime_assert.py to separate verbose/intermediate graph logs from the final output graph log. All verbose logs generated during the FX pass are now routed to a new artifact logger, graph_code_verbose, while only the final output graph remains logged to the original graph_code artifact.
Changes
- Added a new artifact logger: [graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code_verbose")]
- Updated all verbose/intermediate FX pass logs in [insert_deferred_runtime_asserts] to use the new graph_code_verbose artifact.
- Ensured that only the final output graph is logged to the original graph_code artifact.
- No changes to the FX pass logic or output—only logging behavior is affected.
Notes
This change is backward-compatible and does not affect the functional behavior of FX passes.
No changes to user-facing APIs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153775
Approved by: https://github.com/williamwen42
Finally, this PR adds BundledAOTAutogradCacheEntry. A BundledAOTAutogradCacheEntry is an AOTAutogradCacheEntry that saves the entire CompiledFxGraph directly in the entry.
This has some advantages:
- No more dependency on FxGraphCache at all
- Clearing FxGraphCache does not result in AOTAutogradCache miss
- Simpler logic, as BundledAOTAutogradCacheEntry has everything you need to load a full compiled python wrapper from a dynamo output
We plan to use BundledAOTAutogradCacheEntry for precompile. There's also a question of whether we want to use it for regular caching — the main disadvantage of this is having to save the same CompiledFxGraph twice, once in Inductor cache and once for AOTAutogradCache. With MegaCaching, this *could* be a regression in total cache size (as well as a minor cold start regression, as you have to save the same graph twice). I will import this and measure the mega cache space complexity, and if it looks good I'll enable it by default for caching as well.
On warm start, if AOTAutogradCache hits, you won't have to load inductor at all, so warm start overhead should be unaffected.
Differential Revision: [D74593304](https://our.internmc.facebook.com/intern/diff/D74593304)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152840
Approved by: https://github.com/zhxchen17
Motivation:
By default, we are tuning the cutlass backend kernels on 3 swizzles. There are runtime params, so they share the same underlying kernel, which saves a lot of compilation time. However, autotuning all combinations of {configs} x {swizzles} is still expensive.
Observations:
Winner of the {configs} x {swizzles} autotuning is the same as if we do a greedy search: first find the top X winners of {configs} with swizzle 2 (hardcoded), then autotune on the {top X winner configs} x {swizzles}. In other words, we can use a Greedy algorithm to reduce autotuning time.
I attach the logs below. This somewhat depends on what X is, but a number like 5-10 works pretty well from empirical observations.
Logs:
Baseline:
https://gist.github.com/henrylhtsang/9a604f150a270dc19524f72a5d4dfac2
```
AUTOTUNE mm(2048x2048, 2048x2048)
strides: [2048, 1], [1, 2048]
dtypes: torch.bfloat16, torch.bfloat16
cuda_cutlass_gemm_1776 0.0291 ms 100.0% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_1777 0.0291 ms 100.0% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_1778 0.0291 ms 100.0% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_1800 0.0293 ms 99.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_1801 0.0293 ms 99.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_1802 0.0293 ms 99.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_9012 0.0294 ms 98.9% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_9013 0.0294 ms 98.9% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_9014 0.0294 ms 98.9% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_8940 0.0296 ms 98.3% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_8941 0.0296 ms 98.3% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_8942 0.0296 ms 98.3% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_8934 0.0297 ms 98.1% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_8935 0.0297 ms 98.1% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_8936 0.0297 ms 98.1% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_2001 0.0297 ms 97.8% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_2002 0.0297 ms 97.8% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_2003 0.0297 ms 97.8% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_1848 0.0298 ms 97.6% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_1849 0.0298 ms 97.6% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_1850 0.0298 ms 97.6% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_8964 0.0298 ms 97.6% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_8965 0.0298 ms 97.6% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_8966 0.0298 ms 97.6% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_8958 0.0298 ms 97.5% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_8959 0.0298 ms 97.5% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_8960 0.0298 ms 97.5% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_1929 0.0302 ms 96.4% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_1930 0.0302 ms 96.4% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_1931 0.0302 ms 96.4% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x1x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_1770 0.0302 ms 96.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_1771 0.0302 ms 96.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_1772 0.0302 ms 96.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_1953 0.0302 ms 96.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x2x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_1954 0.0302 ms 96.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x2x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_1955 0.0302 ms 96.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x2x1_0_tnn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_1995 0.0303 ms 96.0% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_1996 0.0303 ms 96.0% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_1997 0.0303 ms 96.0% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_1794 0.0303 ms 95.9% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_1795 0.0303 ms 95.9% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_1796 0.0303 ms 95.9% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_1842 0.0303 ms 95.9% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_1843 0.0303 ms 95.9% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_1844 0.0303 ms 95.9% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_9006 0.0304 ms 95.7% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=1
cuda_cutlass_gemm_9007 0.0304 ms 95.7% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=2
cuda_cutlass_gemm_9008 0.0304 ms 95.7% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=4
cuda_cutlass_gemm_1923 0.0306 ms 95.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x1x1_0_tnn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=1
```
with prescreening:
```
AUTOTUNE mm(147456x6144, 6144x2048)
strides: [6144, 1], [2048, 1]
dtypes: torch.bfloat16, torch.bfloat16
cutlass_1a5e81af 4.5469 ms 100.0% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=1
cutlass_aa6f899c 4.6328 ms 98.1% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cutlass_aa6f899c 4.6836 ms 97.1% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cutlass_161b8b81 4.7224 ms 96.3% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cutlass_161b8b81 4.7234 ms 96.3% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cutlass_161b8b81 4.7274 ms 96.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cutlass_853b6347 4.7369 ms 96.0% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cutlass_aa6f899c 4.7404 ms 95.9% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cutlass_161b8b81 4.7711 ms 95.3% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=8
cutlass_8bc6fbda 4.8148 ms 94.4% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=8
cutlass_8bc6fbda 4.8159 ms 94.4% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cutlass_8bc6fbda 4.8214 ms 94.3% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cutlass_8bc6fbda 4.8302 ms 94.1% cutlass3x_sm90_tensorop_s64x256x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cutlass_0a1c55af 4.8487 ms 93.8% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=8
cutlass_0a1c55af 4.8527 ms 93.7% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=2
cutlass_02780d72 4.8617 ms 93.5% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cutlass_0a1c55af 4.8737 ms 93.3% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=1
cutlass_0a1c55af 4.8738 ms 93.3% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=4
cutlass_02780d72 4.9348 ms 92.1% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=1
cutlass_02780d72 4.9763 ms 91.4% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cutlass_853b6347 4.9805 ms 91.3% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=2
cutlass_1a5e81af 5.0225 ms 90.5% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=8
cutlass_853b6347 5.0271 ms 90.4% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=8
cutlass_02780d72 5.0595 ms 89.9% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=8
cutlass_853b6347 5.1434 ms 88.4% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=4
cutlass_c1ffa14b 5.1574 ms 88.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=8
cutlass_1a5e81af 5.1916 ms 87.6% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=4
cutlass_c1ffa14b 5.2018 ms 87.4% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=4
cutlass_c1ffa14b 5.2019 ms 87.4% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=1
cutlass_c1ffa14b 5.2037 ms 87.4% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_256x128x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=2
cutlass_1a5e81af 5.5329 ms 82.2% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_2x1x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma swizzle=2
cutlass_aa6f899c 11.5046 ms 39.5% cutlass3x_sm90_tensorop_s64x128x16gemm_bf16_bf16_f32_void_bf16_128x256x64_1x2x1_0_ttn_align8_warpspecialized_cooperative_epi_tma swizzle=8
SingleProcess AUTOTUNE benchmarking takes 1.9526 seconds and 0.0352 seconds precompiling for 32 choices
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153335
Approved by: https://github.com/eellison
Fixes
```
In file included from /Users/malfet/git/pytorch/pytorch/torch/csrc/distributed/c10d/SymmetricMemory.cpp:1:
/Users/malfet/git/pytorch/pytorch/torch/csrc/distributed/c10d/SymmetricMemory.hpp:77:4: warning: extra ';' after member function definition [-Wextra-semi]
77 | };
| ^
/Users/malfet/git/pytorch/pytorch/torch/csrc/distributed/c10d/SymmetricMemory.hpp:81:4: warning: extra ';' after member function definition [-Wextra-semi]
81 | };
| ^
2 warnings generated.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154034
Approved by: https://github.com/Skylion007
# Motivation
This PR adds scalar tensor (`t.elem()==1 and t.sizes().empty() == true and t.dim()=0` )handling in addmm, baddmm. The issue is found during the process of oneDNN upgradation. Found that the new version of oneDNN requires the post-binary (self in addmm) has same dimension as the one of output tensor. Now we need explicitly expand the shape of `self` tensor. Former version dnnl may help us do the broadcasting inside.
This PR could fix issues in https://github.com/intel/torch-xpu-ops/issues/1612 and CI error in https://github.com/pytorch/pytorch/pull/151767.
# Implementation
We treat the scalar tensor as normal tensor by `unsqueeze` it as 1 dimension tensor. Accompanied with the existing shape handle logic, it would be further `unsqueeze` to 2D or 3D shape.
UT testing
```
python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoXPU.test_comprehensive_addmm_xpu_float32
python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoXPU.test_comprehensive_addmv_xpu_float32
python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoXPU.test_comprehensive_baddbmm_xpu_float16
python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoXPU.test_comprehensive_baddbmm_xpu_float32
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153051
Approved by: https://github.com/EikanWang, https://github.com/guangyey
Summary:
# Context
The MTIA New Aten Backend work is essentially to move MTIA operators from pytorch out-of-tree to in-tree, with following benefits:
1. Avoid duplicate code copied from pytorch, e.g. view ops implementation, util functions.
2. Utilize TensorIterator and structured kernel codegen, avoid manual implementation of broadcasting, dtype casting, asserting, etc.
3. Eliminate MTIA's own codegen flow, which is unnecessary complexity.
4. Overall make MTIA's aten backend more pytorch native.
Differential Revision: D74672464
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153670
Approved by: https://github.com/albanD, https://github.com/nautsimon