Using Philox4 as PRNG
Test plan (other that CI)
Run
```python
mport torch
from torch._inductor.utils import run_and_get_code
from contextlib import nullcontext
def foo(x):
return x * torch.randn_like(x)
foo_c = torch.compile(foo)
x = torch.ones(100, 100, device="mps")
y = foo_c(x)
print(y.mean().item(), y.std().item())
for i in range(25):
print(y[i].mean(), y[i].std())
```
And observe that printed values are close to 0 and 1
TODO: Better `randint` algorithm for large ranges
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145705
Approved by: https://github.com/dcci, https://github.com/jansel
Summary:
The bmm template generates code like this
```
template<bool accum>
void cpp_fused_bmm_66_micro_gemm(...) {
...
}
void single_thread_mm() {
...
cpp_fused_bmm_66_micro_gemm(...)
...
}
void threaded_mm() {
...
cpp_fused_bmm_66_micro_gemm(...)
...
}
void cpp_fused_bmm_66(...)
{
...
single_thread_mm(...);
...
threaded_mm(...);
...
}
```
The generated `fused_bmm` and `fused_bmm_microgemm` functions both have unique identifiers added to their names, but the `single_threaded_mm` and `threaded_mm` do not.
This diff adds unique identifies to those generated functions as well. The identifier is based on the kernel name. So for the example above we would generate a bmm template name like `cpp_fused_bmm_66_single_thread_mm()`.
Differential Revision: D68364772
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145303
Approved by: https://github.com/leslie-fang-intel, https://github.com/frost-intel, https://github.com/hl475
Triton commit 5220 adds tuple support in Triton (changing the indexing format in AttrsDescriptor) and commit 5512 replaces AttrsDescriptor with raw tuples. This PR fixes user-defined triton kernel handling (in most cases) for these new triton commits.
What this PR fixes:
* in triton_kernel_wrap.py, AST->TTIR parsing was to be updated for the new triton API
* ir.py - don't remove None args when using newer triton versions
* wrapper.py - update signature & constant handling
What this doesn't fix:
* correct None handling - I want to do a closer look at constant handling (including None, equal_to_1, and other constants).
* cpp wrapper (which needs to be fixed for both user-defined triton kernels and inductor-generated kernels)
test/inductor/test_triton_kernels.py passed on triton commit 74de6b46, with the exception of three tests (those shown here: 1374074098)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145348
Approved by: https://github.com/jansel
ghstack dependencies: #145051
Triton commit 5220 adds tuple support in Triton (changing the indexing format in AttrsDescriptor) and commit 5512 replaces AttrsDescriptor with raw tuples. This is an initial PR to add support for Triton versions after commit 5512 landed.
The main changes in 5220 and 5512 that need to be supported:
* AttrsDescriptor() gets replaced with a raw dict. The raw dict has the format `{(TUPLES): [["tt.divisibility", 16]]}`, where `(TUPLES)` is a tuple of indices, e.g. `((0,), (1,), (3,))` to indicate that args 0, 1, and 3 are divisible by 16. These indices are, themselves, represented as tuples to support nested inputs (e.g. an argument that's a tuple), but support for tuples is not implemented right now.
* "signature" changes: the signature now contains _all_ args, including constexpr and constant args.
* ASTSource now takes "constexprs" instead of "constants" - for example, equal-to-1 args are constants but not constexprs so we don't need to pass these args as "constants".
What this PR supports:
* Triton versions before Dec 9, 2024, and (partial support for) Triton versions after Jan 1, 2025
* (triton jan 1+) typical inductor-generated triton: updated AttrsDescriptor, signatures, constexpr/constant handling.
What this PR doesn't support (TODO in follow-up PRs):
* Triton versions between Dec 9, 2024 and before Jan 1, 2025
* (triton jan 1+) user-defined triton kernel support (this is implemented already in @anmyachev's patch)
* (triton jan 1+) triton_helper support (failing in triton codegen - needs investigation)
* (triton jan 1+) AOTI / cpp wrapper
thanks to @anmyachev for patches in https://github.com/intel/intel-xpu-backend-for-triton/blob/main/scripts/pytorch.patch, which contains most of these changes already
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145051
Approved by: https://github.com/jansel
Some context: Inplace padding is an optimization to do padding in place. E.g., if a tensor has size [2048, 2047] and stride [2048, 1]. When we need pad one extra element to the end of each row (e.g. during mm padding), we can just reuse the original tensor and do the padding inplace. This saves memory and bandwidth. One caveat for this optimization is, PyTorch does not allocate 2048 elements for the last row of the original tensor. It only allocate 2047 elements. So assuming the last row having enough space for 2048 elements may be wrong and cause OOB memory access (although I never see this happen maybe due to overallocation in the CUDACachingAllocation, this should better be fixed).
The fix is when we allocate the tensor, instead of doing something like:
```
buf0 = randn_strided([2048, 2047], [2048, 1])
```
we do some small overallocation
```
buf0 = randn_strided([2048, 2048], [2048, 1]).as_strided([2048, 2047], [2048, 1])
```
cpp_wrapper needs special handling since memory allocation goes thru different code path to python wrapper.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145325
Approved by: https://github.com/desertfire, https://github.com/jansel
ghstack dependencies: #140249
## Summary
Templated `int8xint8->int32` GEMM that uses AMX ISA (present on Intel Xeon Gen 4 & above). Any epilogues such as weight scale, activation scale, and bias are applied per output block in a fused manner .
Performs well for large values of `M` dimension (assuming canonical dimensions [`M, K`] and [`K, N`] for the activation & weight matrices'/tensors' sizes) when the activation is quantized per-token.
Also supports SmoothQuant GEMM pattern when activation is quantized per-tensor (scalar scale) or per-token (vector scale is applied as an epilogue in this case).
Also increased coverage of GEMM template for uint8 activation, int8 weight GEMM UTs for when the activation zero point is a 1D tensor (the existing implementation only accepted 0D tensors). However, some of such UTs would have to be explicitly enabled with `max-autotune` Inductor config.
## Performance data
The templated codegened fused GEMM with M=32, K=4096, N=14336 used in LLaMA3 exhibits more than 2x perf-gain compared to oneDNN qlinear + mul (for activation's scale) with 48 cores of one socket of Xeon SP 4th gen Platinum 8468 when per-token quantization is used.
For M=1, K=4096, N=14336, regardless of whether per-tensor quantization was used for activation or per-token, the perf gain was more than 3x.
Intel OpenMP & libtcmalloc had been preloaded. All cores used by the workload corresponded to distinct physical cores.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143187
Approved by: https://github.com/jansel, https://github.com/leslie-fang-intel, https://github.com/jgong5
Co-authored-by: Leslie Fang <leslie.fang@intel.com>
May be to be later reused from eager op as well
Also, didn't know that Metal already have type_traits
And use `metal::isunorderder(a, b)` instead of `metal::isnan(a + b)` is it is defined as function that is equivalent `a != a || b != b`, but I suspect it might have a best native implementation for the specific architecture
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145157
Approved by: https://github.com/dcci
This PR adds a heuristic to potentially fail the block pointer match early. Expressions like below take a long time to match using sympy (e.g. > 100 seconds)
```python
# torch._inductor.config.triton.use_block_ptr = True
# torch._inductor.config.triton.prefer_nd_tiling = True
# Expression from pytest -k test_max_pool2d1_dynamic_shapes_cuda:
((xindex//ps1))*((s2 - 3//2))**2 + 2*((xindex//ps1))*((s2 - 3//2)) + ((xindex//ps1)) + ((s2 - 3//2))*(ModularIndexing(xindex, ps0, ps0)) + (ModularIndexing(xindex, 1, ps0)) + (ModularIndexing(xindex, ps0, ps0))
```
Additionally, the heuristic for the number of dimensions based on the indexing expression is refined to only add dimensions for FloorDiv(index, denom) and ModularIndexing(index, denom, modulo) instead of including FloorDiv/ModularIndexing expressions that don't involve the index.
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144681
Approved by: https://github.com/jansel
**Summary**
In this PR, we enable the epilogues fusion and code generation for Grouped GEMM. Here are the high-level description of how we implement it.
**Fusion**
- The Grouped GEMM Template produces a `Template Buffer` with a `MultiOutputLayout` and a set of `MultiOutput Buffers`, where each buffer corresponds to a specific GEMM.
- During the initial round of fusion, the `Template Buffer` and all associated `MultiOutput Buffers` are fused into a `FusedSchedulerNode` by extending the existing fusion design.
- In subsequent fusion rounds, this `FusedSchedulerNode` can further fuse with its epilogues, following the original fusion design principles.
**Code Gen**
We maintain a list of epilogues and codegen it one by one.
- If any of the GEMM has bias, we create a extra `bias_add` epilogue and prepend it at first of the epilogue list.
- If any of the GEMM has no epilogue, we create a `to_bf16` copy epilogue and append it at last of the epilogue list.
**TestPlan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_epilogue
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143897
Approved by: https://github.com/jansel, https://github.com/jgong5
ghstack dependencies: #143796
**Summary**
Enable the CPP Grouped GEMM Fusion, lowering and Grouped GEMM Template following the RFC: https://github.com/pytorch/pytorch/issues/144012
- Support flexible number of GEMMs
- Share activation across GEMMs
- The Grouped GEMM Template supports independent activations
- However, the pattern matcher requires an anchor node, which is as the shared activation across GEMMs
- Each GEMM can have a unique weight but same sizes
- Each GEMM can have a unique bias or None
- Current PR does not yet support biases; this will be addressed in a follow-up epilogue fusion PR
- Each GEMM have its own epilogues
- Epilogue fusion is not yet supported in this PR and will be enabled in an upcoming follow-up epilogue fusion PR
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```
**Example**
Here is the example and generated code
```
batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear0 = torch.nn.Linear(in_features, out_features, bias=False)
self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
def forward(self, x):
return self.linear0(x), self.linear1(x)
if __name__ == "__main__":
with torch.no_grad():
input = torch.randn(batch_size, in_features, dtype=dtype)
m = M(bias=bias).to(dtype=dtype).eval()
cm = torch.compile(m)
act_res = cm(input)
```
Generated Code: https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py
**Next Step**
- Support Epilogue fusion
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel
`metal::isnan` is only defined for floats, so provide a generic wrapper
that is false for integral types
TODO: Figure out why type propagantion is not working (or should it?)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144665
Approved by: https://github.com/dcci
Now error message looks as follows:
```
% python ../test/inductor/test_torchinductor.py -v -k test_cat_unbacked_2d_mps
test_cat_unbacked_2d_mps (__main__.GPUTests) ... inline_call []
stats [('calls_captured', 6)]
inductor [('extern_calls', 2), ('fxgraph_cache_miss', 1)]
aot_autograd [('total', 1), ('autograd_cache_bypass', 1), ('not_ok', 1)]
ERROR
======================================================================
ERROR: test_cat_unbacked_2d_mps (__main__.GPUTests)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/Users/malfet/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3126, in wrapper
method(*args, **kwargs)
File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 12254, in new_test
return value(self)
File "/Users/malfet/miniconda3/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 5885, in test_cat_unbacked_2d
self.common(
File "/Users/malfet/miniconda3/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 620, in check_model_gpu
check_model(
File "/Users/malfet/git/pytorch/pytorch/build/../test/inductor/test_torchinductor.py", line 461, in check_model
actual = run(*example_inputs, **kwargs)
File "/Users/malfet/git/pytorch/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 704, in _compile_fx_inner
raise InductorError(e, currentframe()).with_traceback(
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 689, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 1149, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/compile_fx.py", line 1064, in codegen_and_compile
compiled_fn = graph.compile_to_module().call
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/graph.py", line 1977, in compile_to_module
return self._compile_to_module()
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/graph.py", line 2018, in _compile_to_module
mod = PyCodeCache.load_by_key_path(
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/codecache.py", line 2768, in load_by_key_path
mod = _reload_python_module(key, path)
File "/Users/malfet/git/pytorch/pytorch/torch/_inductor/runtime/compile_tasks.py", line 51, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/var/folders/sc/2thx6_x95h7_h9qs8s48yh140000gn/T/tmpmyfz2ju8/lt/cltm34ognlgcc6oxoe6bexvtbwcdtdfgnkjj5miz7vhkemitacp7.py", line 40, in <module>
File "/var/folders/sc/2thx6_x95h7_h9qs8s48yh140000gn/T/tmpmyfz2ju8/lt/cltm34ognlgcc6oxoe6bexvtbwcdtdfgnkjj5miz7vhkemitacp7.py", line 32, in _compile_mps_shader
torch._inductor.exc.InductorError: SyntaxError: failed to compile
kernel void generated_kernel(
device float* out_ptr0,
constant float* in_ptr0,
uint xindex [[thread_position_in_grid]]
) {
long x1 = (xindex) / (3);
auto tmp0 = x1;
auto tmp1 = static_cast<long>(tmp0);
auto tmp2 = 0;
auto tmp3 = tmp1 >= tmp2;
auto tmp4 = 2;
auto tmp5 = tmp1 < tmp4;
long x0 = (xindex) % (3);
auto tmp6 = in_ptr0[x0 + 3*(x1)];
auto tmp7 = tmp5 ? tmp6 : 0.0;
auto tmp8 = tmp1 >= tmp4;
auto tmp9 = 2 + ks0;
auto tmp10 = static_cast<long>(tmp9);
auto tmp11 = tmp1 < tmp10;
auto tmp12 = 1.0;
auto tmp13 = tmp8 ? tmp12 : 0.0;
auto tmp14 = tmp5 ? tmp7 : tmp13;
long x2 = xindex;
out_ptr0[x2] = static_cast<float>(tmp14);
}
with program_source:18:25: error: use of undeclared identifier 'ks0'
auto tmp9 = 2 + ks0;
^
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
To execute this test, run the following from the base repo dir:
python test/inductor/test_torchinductor.py GPUTests.test_cat_unbacked_2d_mps
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 1 test in 0.472s
FAILED (errors=1)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144649
Approved by: https://github.com/Skylion007, https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #144647, #144648
Just pass them as kernel arguments
After this change `pytest test/inductor/test_torchinduct.py -v -k _mps` reports 330 failed, 429 passed after and 335 failed, 424 passed before
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144662
Approved by: https://github.com/jansel