This PR enables additional Inductor unit tests for Intel GPU. Due to the increased number of test cases, the number of runners has been extended from 8 to 12 to prevent CI timeouts.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166047
Approved by: https://github.com/jansel
Co-authored-by: Deng, Daisy <daisy.deng@intel.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Summary:
When using the inductor pattern matcher to replace graphs, the graph generated by the replacement function can be missing `original_aten` metadata on the replaced nodes. This in turn causes inductor to fail to generate a sensible kernel name, e.g. `tri_poi_fused_0`, which is missing the aten op name.
This diff attempts to fix that by tracing the graph in the replacement function under `preserve_node_meta`. This is exposed as an option that can be turned on in the `pattern_matcher.fwd_only` function.
Can confirm that with the fix, MTIA's pattern matcher replaces the original graph with a node that has `original_aten` meta, and the inductor-generated kernel name includes the op name.
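A minimal sketch of where this plugs in, using the existing `register_replacement`/`fwd_only` entry points; the exact name of the new opt-in is an assumption here, not taken from the diff:
```py
# Sketch only: the new opt-in lives in pattern_matcher.fwd_only; its flag name is assumed.
import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

patterns = PatternMatcherPass()

def search_fn(x, y):
    return torch.relu(torch.mm(x, y))

def replace_fn(x, y):
    # With the fix, nodes traced from this function can keep node meta
    # (including original_aten), so the generated kernel name carries the op.
    return torch.clamp_min(torch.mm(x, y), 0.0)

register_replacement(
    search_fn,
    replace_fn,
    [torch.empty(4, 4), torch.empty(4, 4)],  # example inputs used for tracing
    fwd_only,  # this PR adds an option here to trace under preserve_node_meta
    patterns,
)
```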
Test Plan:
Added a kernel_name check to the afg_inductor_test silu test.
Rollback Plan:
Differential Revision: D80183670
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160542
Approved by: https://github.com/eellison, https://github.com/bdhirsh
This PR implements a pass in post_grad to fuse activation(add + mm) into `aten._addmm_activation`.
Something similar was previously done in #106912 but was reverted for performance reasons; it was replaced with a pass that unfuses the activation and add from addmm/addmm_activation and lets inductor handle the fusion.
However, since then the cuBLAS team has made a lot of perf improvements here. I will update this post with more benchmarks, but preliminary benchmarks show good results.
Perf dashboard:
<img width="3371" height="1240" alt="Screenshot from 2025-08-07 13-41-35" src="https://github.com/user-attachments/assets/d44d6205-b33a-4a20-9f0f-d9db176b3738" />
Relu works with both training and inference, but gelu only works in inference mode due to a fundamental limitation: gelu's derivative depends on its input while relu's doesn't. I don't think this is fixable with the current addmm_activation API.
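For context, a minimal way to exercise the pass (a sketch; it assumes a CUDA build and that the pass is on by default):
```py
import torch

# Hypothetical repro; sizes are arbitrary.
lin = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)

def f(x):
    return torch.relu(lin(x))

compiled = torch.compile(f)
x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    compiled(x)
# Dump the post-grad graph (e.g. TORCH_LOGS="post_grad_graphs") and look for
# torch.ops.aten._addmm_activation.default in place of addmm + relu.
```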
Graph module before and after this pass
Relu(addmm)
```
graph():
%primals_1 : [num_users=1] = placeholder[target=primals_1]
%primals_2 : [num_users=2] = placeholder[target=primals_2]
%primals_3 : [num_users=2] = placeholder[target=primals_3]
%addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
%relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%addmm,), kwargs = {})
%le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
%permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
return (relu, primals_2, le, permute_1)
graph():
%primals_1 : [num_users=1] = placeholder[target=primals_1]
%primals_2 : [num_users=2] = placeholder[target=primals_2]
%primals_3 : [num_users=2] = placeholder[target=primals_3]
%_addmm_activation_default : [num_users=2] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
%le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%_addmm_activation_default, 0), kwargs = {})
%permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
return (_addmm_activation_default, primals_2, le, permute_1)
```
Gelu (addmm)
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%addmm : [num_users=4] = call_function[target=torch.ops.aten.addmm.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %addmm), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %addmm), kwargs = {})
%mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_1, 0.044715), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%addmm, %mul_2), kwargs = {})
%mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 0.7978845608028654), kwargs = {})
%mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, 0.5), kwargs = {})
%tanh : [num_users=1] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_3,), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh, 1), kwargs = {})
%mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_4, %add_1), kwargs = {})
return (mul_5,)
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%_addmm_activation_default : [num_users=1] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {use_gelu: True})
return (_addmm_activation_default,)
```
Benchmark setup:
NGC pytorch 25.06 container
cublas version: 12.9.1.4
torch.compile ran with dynamic = False and max_autotune
H100
```
Testing with M=1024, N=1024, K=1024, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.0107 ms
Average Time per Iteration (torch compile): 0.0296 ms
============================================================
Testing with M=2048, N=2048, K=2048, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.0262 ms
Average Time per Iteration (torch compile): 0.0327 ms
============================================================
Testing with M=4096, N=4096, K=4096, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 0.1763 ms
Average Time per Iteration (torch compile): 0.2457 ms
============================================================
Testing with M=8192, N=8192, K=8192, dtype=bfloat16
============================================================
Average Time per Iteration (cublas): 1.5280 ms
Average Time per Iteration (torch compile): 1.9437 ms
```
A100
```
############################################################
Testing with dtype: float16
############################################################
============================================================
Testing with M=1024, N=1024, K=1024, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.0313 ms
Average Time per Iteration (torch compile): 0.0643 ms
============================================================
Testing with M=2048, N=2048, K=2048, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.1149 ms
Average Time per Iteration (torch compile): 0.1255 ms
============================================================
Testing with M=4096, N=4096, K=4096, dtype=float16
============================================================
Average Time per Iteration (cublas): 0.6297 ms
Average Time per Iteration (torch compile): 0.7547 ms
============================================================
Testing with M=8192, N=8192, K=8192, dtype=float16
============================================================
Average Time per Iteration (cublas): 4.3821 ms
Average Time per Iteration (torch compile): 5.0740 ms
```
Script
```py
import torch

torch.manual_seed(0)

warmup, numrun = 10, 100
sizes = [1024, 2048, 4096, 8192]
dtypes = [torch.float16, torch.bfloat16, torch.float32]
device = torch.device("cuda")

for dtype in dtypes:
    dtype_name = str(dtype).split('.')[-1]
    print(f"\n{'#'*60}")
    print(f"Testing with dtype: {dtype_name}")
    print(f"{'#'*60}")

    for size in sizes:
        M, N, K = size, size, size
        print(f"\n{'='*60}")
        print(f"Testing with M={M}, N={N}, K={K}, dtype={dtype_name}")
        print(f"{'='*60}")

        A = torch.randn(M, K, device=device, dtype=dtype)
        B = torch.randn(K, N, device=device, dtype=dtype)
        C = torch.randn(M, device=device, dtype=dtype)

        def func1():
            return torch._addmm_activation(C, A, B, use_gelu=True)

        def func2():
            return torch.nn.functional.gelu(torch.add(C, torch.mm(A, B)), approximate="tanh")

        func2_compiled = torch.compile(
            func2,
            dynamic=False,
            options={
                "force_disable_caches": True,
                "max_autotune": True,
                "max_autotune_gemm": True,
                "max_autotune_gemm_backends": "TRITON",
                "autotune_fallback_to_aten": False,
            },
        )

        # cuBLAS path: fused _addmm_activation
        for _ in range(warmup):
            func1()
        torch.cuda.synchronize(device=device)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        total_time_ms = 0.0
        start_event.record()
        for _ in range(numrun):
            func1()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)
        avg_time_ms = total_time_ms / numrun
        print(f"Average Time per Iteration (cublas):\t {avg_time_ms:.4f} ms")

        # torch.compile path: unfused add + mm + gelu, Triton max-autotune
        for _ in range(warmup):
            func2_compiled()
        torch.cuda.synchronize(device=device)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        total_time_ms = 0.0
        start_event.record()
        for _ in range(numrun):
            func2_compiled()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)
        avg_time_ms = total_time_ms / numrun
        print(f"Average Time per Iteration (torch compile):\t {avg_time_ms:.4f} ms")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158137
Approved by: https://github.com/eellison
Summary:
This change introduces a fallback path from `bmm` to `mm` when the batch dimension is `1`.
The motivation is to unlock specialized `mm` kernel paths (e.g., `decomposeK`, `persistent+TMA`, etc.) which often don't have `bmm` equivalents.
### Rationale
- **No regression:** On shapes where the fallback triggers, we see no performance loss.
- **Performance wins:** On select shapes (especially with large `K`), we observe measurable speedups by triggering `mm`-specific optimizations.
For example, on `bmm` shapes of the form `(1, H, K, H)` where `H ∈ {16, 32, 48, 64}` and `K ∈ {4096 ... 32768}`, we see an **average speedup of 10%**.
- **Prevalence in prod:** Internal workloads frequently emit `bmm` ops with `batch=1`, making this fallback broadly useful in practice.
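As a rough illustration of the shapes described above (a sketch; which kernel actually wins still depends on autotuning):
```py
import torch

# batch=1 bmm in the (1, H, K) x (1, K, H) family with large K; with the
# fallback, inductor can consider mm-only choices (e.g. decomposeK,
# persistent+TMA) instead of being limited to bmm templates.
H, K = 64, 16384
a = torch.randn(1, H, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1, K, H, device="cuda", dtype=torch.bfloat16)

compiled_bmm = torch.compile(torch.bmm, mode="max-autotune")
out = compiled_bmm(a, b)
```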
Test Plan:
contbuild & OSS CI
Tests in test/inductor/test_torchinductor.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153572
Approved by: https://github.com/PaulZhang12, https://github.com/eellison
Now that torchinductor supports prologue fusion, we can delete all the mixed mm code. When I benchmarked int8 weight-only mm in the new path against int8 mm in the old path in the [following benchmark](https://gist.github.com/eellison/46e321709572c11c077d0612cb3492b7), I got a 1.244x geomean speedup on Huggingface linear shapes with bias. There are a couple of reasons for the speedup:
- Prologue fusion is often unprofitable, even for int8 mm. Because the current mixed mm benchmarking only compares triton_int8_mm vs (dtype_conversion + cublas), we miss out on scenarios where the triton template is profitable but the prologue fusion is not.
- Similarly, we miss out on potential epilogue fusions like bias if we dispatch to the [fallback mixed mm](5006932cbc/torch/_inductor/kernel/mm.py (L750-L751)) that mixed_mm will dispatch to, instead of the deferred epilogue tuning in the current path.
It's possible some of the speedups would be smaller on larger models where the epilogue might get fused into a following kernel. Nonetheless, even if this is perf neutral it is worth landing for the code deduplication.
The one kernel that is a little special and would not fall out of prologue fusion is the uint4x2_mixed_mm kernel. It's still possible to generate it with prologue fusion, just not exactly as the current [impl](bd370c138a/torch/_inductor/kernel/unpack_mixed_mm.py (L43-L49)) does. But the current impl does not compare against a cublas baseline, and I found that it makes things slower (35% slower on a not particularly big 1024x1024x1024 mm shape on H100), so it should be fine to delete.
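For reference, a sketch of the int8 weight-only mm that now goes through the regular mm path, with the dtype conversion picked up as a fusable prologue (the actual dispatch depends on max-autotune results):
```py
import torch

x = torch.randn(128, 4096, device="cuda", dtype=torch.bfloat16)
w_int8 = torch.randint(-128, 127, (4096, 4096), device="cuda", dtype=torch.int8)

def int8_weight_only_mm(x, w):
    # The .to() conversion can be fused into the Triton mm template as a
    # prologue (and a bias epilogue can fuse too), instead of dispatching to
    # the deleted mixed_mm kernels.
    return torch.mm(x, w.to(x.dtype))

compiled = torch.compile(int8_weight_only_mm, mode="max-autotune")
out = compiled(x, w_int8)
```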
Future optimizations could include:
- cutlass prologue path
- making prologue fusion support the persistent TMA-based mm template. From @drisspg's experience this led to nice wins with fp8 but not as nice wins with bf16 mm; I think lower-memory-bandwidth int8 mm would similarly benefit.
Differential Revision: [D70114858](https://our.internmc.facebook.com/intern/diff/D70114858)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147151
Approved by: https://github.com/drisspg, https://github.com/cpuhrsch
vllm hit an error because we were incorrectly stating that two patterns are duplicates. See comment inline:
For a particular generated pattern repr, store all the equivalent graphs that were used to generate it. Because we ignore certain patterns in searching, but not in matching, use the graph to distinguish whether two equivalent searches are actually different.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139321
Approved by: https://github.com/shunting314
Reuse partial reductions for complete reductions. We could expand this to cover more types of reductions, although we'd have to be a bit more careful about keeping the intermediary partial reduction in higher precision.
For now, only the ops that do not depend on a higher compute_dtype_precision are handled, to cover the relevant use case initially.
Fix for https://github.com/pytorch/pytorch/issues/136267. Longer term, we should make sure cooperative reductions fuse partial and complete reductions.
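A minimal sketch of the kind of graph this targets (whether the reuse actually fires depends on the reduction op and dtype):
```py
import torch

def f(x):
    row_max = x.amax(dim=-1)  # partial reduction
    global_max = x.amax()     # complete reduction over the same input
    return row_max, global_max

# With this change, inductor can compute `global_max` from `row_max`
# instead of reading all of `x` a second time.
compiled = torch.compile(f)
compiled(torch.randn(4096, 4096, device="cuda"))
```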
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143600
Approved by: https://github.com/vkuzo
This is sort of subtle: because we were doing `V.ops.mul` at binding time, we don't redispatch later when we invoke the epilogue, and we then run into the assertion checking added in the PR above.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143127
Approved by: https://github.com/drisspg
ghstack dependencies: #143048
Using EC2 G6 instances (based on NVIDIA L4), added to the scale config in https://github.com/pytorch/test-infra/pull/5376
To enable more balanced sharding, had to push 148ae19935
Added `@xfailIfSM89` to the following tests:
- test_fp8_pattern_2
- test_original_aten_preserved_split_addmm
- test_sparse_semi_structured_scaled_mm
- test_sparse_semi_structured_scaled_mm_fp8
- test_sparse_fp8fp8_mm
Increased tolerance to 2e-4 for `RNNTest.BidirectionalMultilayerGRU_CPU_vs_CUDA`
Skipped the following inductor tests (which either flakily OOM or time out):
- test_reduction_fn_std_float64
- test_reduction_fn_var_mean_float64
- test_multi_output_unbacked_custom_op
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140305
Approved by: https://github.com/wdvr, https://github.com/ZainRizvi
Fixes #137280
When we have multiple indexings of the same array as returned items in a pattern replacement, we shouldn't ignore the indexing numbers; otherwise, we may create a wrong pattern_to_node mapping.
A unit test is added in this PR. In this unit test, the function `rms_pattern_static` is replaced with `rms_replacement_static` when called. The function `rms_pattern_static` calls two functionalized custom operators, `torch.ops.vllm.rms_norm.default` and `torch.ops.vllm.static_scaled_int8_quant.default`, and returns at2[1] and at2[2] as outputs. The function `rms_replacement_static` calls one functionalized custom operator, `torch.ops.vllm.fused_rms_norm_quant_static.default`, which returns the two corresponding items.
Run `python test/inductor/test_pattern_matcher.py -k test_multioutput_register_replacement` to test. After setting `TORCH_COMPILE_DEBUG` to 1, the final part of `fx_graph_readable.py` looks like the following.
```python
# File: /home/yhao/p9/pytorch/test/inductor/test_pattern_matcher.py:1673 in rms_pattern_static, code: at1 = auto_functionalized(
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.rms_norm.default, result = permute_1, input = convert_element_type, weight = convert_element_type_1, epsilon = 1e-06); permute_1 = convert_element_type = convert_element_type_1 = None
getitem_1: "bf16[5, 4]" = auto_functionalized[1]; auto_functionalized = None
# File: /home/yhao/p9/pytorch/test/inductor/test_pattern_matcher.py:1680 in rms_pattern_static, code: at2 = auto_functionalized(
auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.static_scaled_int8_quant.default, result = permute, input = getitem_1, scale = full_default, azp = None); permute = getitem_1 = full_default = None
getitem_3: "i8[5, 4]" = auto_functionalized_1[1]
getitem_4: "f32[1, 1]" = auto_functionalized_1[2]; auto_functionalized_1 = None
return (getitem_3, getitem_4)
```
This happens before pattern matching, so it is expected to call `static_scaled_int8_quant` and `rms_norm` and return items of auto_functionalized_1 as outputs.
However, with pytorch before this PR, `fx_graph_transformed.py`, which is generated after pattern matching, contains the following code.
```python
# File: /home/yhao/p9/pytorch/test/inductor/test_pattern_matcher.py:1748 in my_func_static, code: scale = torch.ones((1, 1))
full_default: "f32[1, 1]" = torch.ops.aten.full.default([1, 1], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
# No stacktrace found for following nodes
as_strided_default: "i8[20]" = torch.ops.aten.as_strided.default(permute, [20], [1], 0)
clone_default: "i8[20]" = torch.ops.aten.clone.default(as_strided_default); as_strided_default = None
as_strided_default_1: "i8[5, 4]" = torch.ops.aten.as_strided.default(clone_default, [5, 4], [4, 1], 0); clone_default = None
as_strided_default_2: "f32[1]" = torch.ops.aten.as_strided.default(full_default, [1], [1], 0)
clone_default_1: "f32[1]" = torch.ops.aten.clone.default(as_strided_default_2); as_strided_default_2 = None
as_strided_default_3: "f32[1, 1]" = torch.ops.aten.as_strided.default(clone_default_1, [1, 1], [1, 1], 0); clone_default_1 = None
static_scaled_int8_quant_default = torch.ops.vllm.static_scaled_int8_quant.default(as_strided_default_1, permute_1, as_strided_default_3); as_strided_default_1 = permute_1 = static_scaled_int8_quant_default = None
fused_rms_norm_quant_static_default = torch.ops.vllm.fused_rms_norm_quant_static.default(permute, convert_element_type, convert_element_type_1, full_default, None, 1e-06); convert_element_type = convert_element_type_1 = full_default = fused_rms_norm_quant_static_default = None
return (permute, as_strided_default_3)
```
Here, it returns `(permute, as_strided_default_3)`, where `permute` is written by `fused_rms_norm_quant_static` and `as_strided_default_3` is written by `static_scaled_int8_quant`. This is wrong: `static_scaled_int8_quant` should have been removed, since it is replaced by `fused_rms_norm_quant_static`. The graph is supposed to return `(permute, full_default)`.
The root cause is the following. When we [generate patterns](5f4a21dc58/torch/_inductor/pattern_matcher.py (L1580)) from the traced fx graph and call the function below, the indexing numbers' type `int` from the traced graph is in `ignore_types`, so the indices are dropped. As a result, the final arguments of the patterns for those two output items look like `(CallFunction(auto_functionalized, XXX), *)`.
5f4a21dc58/torch/_inductor/pattern_matcher.py (L1839-L1847)
When we do pattern matching after the patterns are generated, in the following part, `sorted(itertools.chain.from_iterable(nodes), reverse=True)` is `[getitem_4, getitem_3, getitem_1]`. The getitem_4 iteration is always a FailedMatch because we always use the first element to do the pattern match here (it fails on different match functions before and after this PR, but the reason is always the indexing-numbers issue): d4cdc09881/torch/_inductor/pattern_matcher.py (L848). However, when we do pattern matching for getitem_3, the child_match returns a match for getitem_3 again, because the `*` pattern can match anything. Then getitem_3's pattern matching returns `[getitem_3, getitem_3]` as outputs, which is wrong.
d4cdc09881/torch/_inductor/pattern_matcher.py (L856)
d4cdc09881/torch/_inductor/pattern_matcher.py (L1750-L1774)
This PR stops ignoring the `int` type when generating patterns for getitem functions, because the integer indexing numbers matter for them. Thus, the indexing information is kept in the patterns, ensuring correct matching. With this PR, the above `child_match` returns a match for getitem_4, and getitem_3's pattern matching returns the correct `[getitem_3, getitem_4]`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140193
Approved by: https://github.com/eellison
These are not artificial patterns I came up with; they show up in the linear+CrossEntropyLoss graph.
Consider this snippet:
```
class LinearAndCEL(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(C, V)
        self.ce = nn.CrossEntropyLoss()

    def forward(self, x, y):
        return self.ce(self.linear(x).view(B * T, V), y.view(-1))
```
`x` passed to `forward` is a 3D tensor of shape [B, T, C].
`self.linear` will first view x as a [BxT, C] tensor, do the matmul to produce a [BxT, V] tensor, and then view that output back to a 3D tensor of shape [B, T, V]. The user code then adds another view op to convert the tensor shape back to [B x T, V]. This generates a pair of redundant views. A pair of redundant permutes shows up in the backward pass when we compute gradients.
The view ops make it hard to chunk linear+CEL: when a view op breaks up the dimension being chunked, what should the chunker do (even if we merge those dimensions again later)? Removing these pointless view pairs makes the chunker simpler, and I think it is a nice cleanup in general.
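Spelled out with the shapes above, a sketch of the redundant pair the pass removes (`w`/`b` stand in for the linear's parameters; sizes are arbitrary):
```py
import torch

B, T, C, V = 2, 8, 16, 32
x = torch.randn(B, T, C)
w = torch.randn(V, C)
b = torch.randn(V)
y = torch.randint(0, V, (B, T))

# Inside nn.Linear (decomposed): flatten to 2D for the matmul, then restore 3D.
x2d = x.view(B * T, C)
out3d = torch.addmm(b, x2d, w.t()).view(B, T, V)
# User code immediately flattens again, cancelling the previous view.
logits = out3d.view(B * T, V)
loss = torch.nn.functional.cross_entropy(logits, y.view(-1))
```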
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139136
Approved by: https://github.com/Chillee, https://github.com/jansel
# Motivation
Fix https://github.com/pytorch/pytorch/issues/138577.
# Solution
1. All UTs in `test/inductor/test_compiled_optimizers.py` are fixed by https://github.com/pytorch/pytorch/pull/134170
2. A UT in `test/inductor/test_pattern_matcher.py` introduced by https://github.com/pytorch/pytorch/pull/138089 is skipped due to the unsupported feature `max_autotune_gemm_backends:Triton`.
3. We have a new impl related to `histc`, so we remove the expected failure from `test/inductor/test_torchinductor_opinfo.py`
4. We support `avg_pool3d` for `fp16` data type, so we remove the expected failure from `test/inductor/test_torchinductor_opinfo.py`
5. CUDA-biased code was introduced by https://github.com/pytorch/pytorch/issues/138472; we generalize it to `GPU_TYPE`.
# Additional Context
> Why update torch-xpu-ops commit pin here?
We have to update the commit pin to avoid the build failure caused by the code change [C10_UNUSED](https://github.com/pytorch/pytorch/pull/138364).
> What does the torch-xpu-ops update include?
1. Add some foreach ops, like unary ops and `foreach_clamp_max`, etc.;
2. Add forward and backward for some pooling ops, like `avg_pool3d` and `max_pool3d`;
3. Add some other ops, like `log_normal_`, `index_copy`, and `mode`, etc.;
4. Fix the build failure related to `C10_UNUSED`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138548
Approved by: https://github.com/malfet, https://github.com/EikanWang
replace_by_example is used to implement some pattern-matching passes in inductor. Previously, replace_by_example would generate nodes with very little metadata. In particular, `meta["original_aten"]` would be lost; that meant that when generating triton kernel names, you could get uninformative names like `triton_tem_fused_0` if the input nodes to the fused kernel were the result of a pattern-matching pass that used replace_by_example.
This also adds metadata to register_replacement patterns, including pad_mm.
This fixes the issue by copying metadata from the original node to the replacement nodes. If there are multiple original nodes, we skip the metadata transfer; so if you have an `add(z, mm(x, y))`, the metadata won't be transferred right now.
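For illustration, a sketch of a replace_by_example-based pass (the pattern here is made up and is not one of the passes touched by this PR); with this change the replacement node should inherit the matched node's metadata, including `original_aten`:
```py
import torch
from torch._inductor.pattern_matcher import (
    CallFunction,
    KeywordArg,
    PatternMatcherPass,
    register_graph_pattern,
)

my_patterns = PatternMatcherPass()

@register_graph_pattern(
    CallFunction(torch.ops.aten.relu.default, KeywordArg("x")),
    pass_dict=my_patterns,
)
def relu_to_clamp(match, x):
    def repl(x):
        return torch.clamp_min(x, 0)

    # Nodes traced from `repl` now carry metadata copied from the matched relu
    # node, so the generated kernel names stay informative.
    match.replace_by_example(repl, [x])
```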
Differential Revision: [D64480755](https://our.internmc.facebook.com/intern/diff/D64480755)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138089
Approved by: https://github.com/aakhundov
When we are autotuning matmuls, the aten.mm and triton template choices take in an externally allocated tensor that can be a view into a pre-planned aten.cat. As long as the output shape and stride of the matmul match the slice of the cat we are planning, we can realize the mm directly into the cat.
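A minimal sketch of a graph where this applies (assuming max-autotune is enabled and the mm output layout matches the planned cat slice):
```py
import torch

def f(x, w1, w2):
    # Each mm output can be realized directly into the pre-planned cat buffer
    # when its shape and stride match the corresponding slice of the cat.
    return torch.cat([x @ w1, x @ w2], dim=1)

compiled = torch.compile(f, mode="max-autotune")
x = torch.randn(256, 512, device="cuda", dtype=torch.float16)
w1 = torch.randn(512, 512, device="cuda", dtype=torch.float16)
w2 = torch.randn(512, 512, device="cuda", dtype=torch.float16)
out = compiled(x, w1, w2)
```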
Discussion for reviewers:
It feels a little bit odd that in the existing code we set the output of aten.mm to [FlexibleLayout](bcac71517c/torch/_inductor/kernel/mm.py (L156)). While this is correct, it might lead to passing non-performant output strides to cublas. I guess this is better than a copy? Not sure. We could also introduce a Layout that denotes a fixed shape and stride whose allocation we control:
```
class AllocatedFixedLayout(FixedLayout)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132554
Approved by: https://github.com/jansel