This diff introduces separate logging of autotuning results,
with the intention of making the results analyzable, specifically
those for the new experimental Cutlass backend.
Results are logged as text files with one JSON document corresponding to a single benchmark result per line.
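For illustration, a minimal sketch of the one-JSON-document-per-line format described above. The helper names and the record fields (`backend`, `kernel`, `time_ms`) are assumptions made up for this example, not the fields the PR actually logs.

```python
import io
import json


def log_benchmark_result(stream, record):
    # One JSON document per line: serialize without embedded newlines.
    stream.write(json.dumps(record, sort_keys=True) + "\n")


def read_benchmark_results(stream):
    # Each non-empty line is parsed independently of the others.
    return [json.loads(line) for line in stream if line.strip()]


# Hypothetical records; the real log fields may differ.
buf = io.StringIO()
log_benchmark_result(buf, {"backend": "cutlass", "kernel": "gemm_a", "time_ms": 0.42})
log_benchmark_result(buf, {"backend": "triton", "kernel": "gemm_b", "time_ms": 0.51})
buf.seek(0)
results = read_benchmark_results(buf)
fastest = min(results, key=lambda r: r["time_ms"])
```

Because every line is a complete document, the file can be appended to during autotuning and analyzed later line by line.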
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119004
Approved by: https://github.com/jansel
ghstack dependencies: #120620
Summary: Currently, when a custom (user-written) Triton kernel has a ReinterpretView argument in IR, we're always skipping the alignment checking for this argument when preparing the `signature_of` for the AOT compilation of the Triton kernel (via setting `TensorArg.check_alignment` to `False`). This is problematic for user-written kernels where, albeit reinterpreted, the argument of the Triton kernel (the data pointer) can still be aligned to 16. When we skip alignment checking, the performance of the AOT-compiled internal Triton kernels can degrade 2x--3x.
In this PR, we replace `TensorArg.check_alignment` by `TensorArg.offset`, in which we specify the offset of the `ReinterpretView.layout` relative to the underlying `ir.Buffer` (corresponding to the data pointer before reinterpretation). As the size and stride of the layout don't change the alignment properties, those can be skipped. Importantly, for `ReinterpretView` arguments of custom Triton kernels, we use `arg.data.get_name()` as the buffer name. That, together with the offset, is used to check the alignment.
Bonus: the namedtuples in `codegen/common.py` are refactored as `dataclass`es, with nicer type hints and default values (for the newly added `TensorArg.offset`).
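A simplified sketch of the idea, not the actual `codegen/common.py` code: the field `dtype_size`, the helper `is_aligned`, and the constant `ALIGNMENT` are hypothetical, and the offset is assumed to be counted in elements.

```python
from dataclasses import dataclass

ALIGNMENT = 16  # bytes; the alignment the text above refers to


@dataclass
class TensorArg:
    name: str
    buffer: str          # name of the underlying buffer
    dtype_size: int      # hypothetical field: element size in bytes
    offset: int = 0      # offset relative to the underlying buffer, in elements


def is_aligned(arg, buffer_is_aligned=True):
    # The reinterpreted pointer is aligned if the base buffer is aligned
    # and the byte offset into it is a multiple of ALIGNMENT.
    return buffer_is_aligned and (arg.offset * arg.dtype_size) % ALIGNMENT == 0


plain = TensorArg("in_ptr0", "buf0", dtype_size=4)              # default offset 0
shifted_ok = TensorArg("in_ptr1", "buf0", dtype_size=4, offset=4)   # 16 bytes
shifted_bad = TensorArg("in_ptr2", "buf0", dtype_size=4, offset=3)  # 12 bytes
```

The default value for `offset` is what the dataclass form makes convenient compared to a namedtuple without defaults.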
Test Plan:
```
$ python test/inductor/test_aot_inductor.py -k test_triton_kernel_reinterpret_view
...
----------------------------------------------------------------------
Ran 6 tests in 27.952s
OK (skipped=4)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119649
Approved by: https://github.com/oulgen
This PR adds a new type of triton kernel in which data is persistent but the
reduction dimension is split over multiple blocks (up to the entire kernel).
Though this is called a reduction dimension, in practice we only support scans.
Because of this limitation, I have to be able to block fusions of split scan
operations with reductions, so I chose to add a new `ir.SplitScan` node which
is otherwise identical but lets the scheduler tell the two apart.
The split scan kernel is also the first to require an additional workspace buffer,
which is used to communicate between CUDA blocks. This is slightly tricky as
the exact scratch space requirement isn't known until the grid size is calculated.
Here I work around the issue by setting a minimum rblock size and always allocating
for the maximum possible grid size for a given input tensor.
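A sequential Python emulation of the scheme, as a rough sketch only. `MIN_RBLOCK`, `workspace_size` and `split_cumsum` are hypothetical names, and the real kernel runs blocks concurrently rather than in a loop.

```python
import math
from itertools import accumulate

MIN_RBLOCK = 4  # hypothetical minimum block size along the scanned dimension


def workspace_size(rnumel, min_rblock=MIN_RBLOCK):
    # Upper bound on the number of blocks, known before the grid is chosen:
    # no launch can use more blocks than the minimum block size allows.
    return math.ceil(rnumel / min_rblock)


def split_cumsum(values, rblock):
    assert rblock >= MIN_RBLOCK
    # Allocate for the worst case, even if this launch uses fewer blocks.
    workspace = [0] * workspace_size(len(values))
    out = []
    num_blocks = math.ceil(len(values) / rblock)
    for block in range(num_blocks):
        chunk = values[block * rblock:(block + 1) * rblock]
        local = list(accumulate(chunk))                     # scan within the block
        carry = workspace[block - 1] if block > 0 else 0    # read predecessor's total
        out.extend(carry + v for v in local)
        workspace[block] = carry + local[-1]                # publish running total
    return out


data = list(range(1, 11))
result = split_cumsum(data, rblock=4)
```

The workspace is the only channel between blocks here, which mirrors why the scratch space has to be sized from the grid.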
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117992
Approved by: https://github.com/jansel
ghstack dependencies: #117991
I was just playing around with improving the typing of symbolic_shapes. The PR is not "complete" but I in particular wanted to get feedback on whether or not people liked making ValueRanges Generic; it seems that distinguishing if you have an Expr ValueRange or a SympyBoolean ValueRange is a lot of trouble for downstream. Using TypeGuard, we can perform refinements on the generic parameter inside methods, although we still have to cast back to ValueRange[T] due to https://github.com/python/mypy/issues/14425#issuecomment-1914852707
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118529
Approved by: https://github.com/Skylion007
Improvements to shape padding logic in torch/_inductor/pad_mm.py
These changes could yield up to a 14% perf improvement for certain Meta internal models in experiments.
Most notably:
* 1.) Use the aten.constant_pad_nd operation to pad Tensors in a single op instead of using multiple steps involving intermediate buffers. This appears to be more performant than the previous logic, as confirmed by profiling and benchmarking results (Meta internal).
* 2.) Make many paddings unnecessary by using an explicitly transposed GEMM when either the M or the N dimension is properly aligned but the other is not, configurable via config.shape_pad_use_transpose (default: True).
* 3.) Enable shape padding for the Inductor CUDA / Cutlass backend for all GEMM ops where Cutlass would be enabled, without benchmarking in that case.
* Add a config flag to always pad shapes (without benchmarking first), configurable via config.force_shape_pad (default: False).
* Added several new unit tests to ensure tensors are padded such that they meet all alignment requirements after padding.
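A small sketch of the arithmetic behind single-op padding. The helper names and the alignment value of 8 are assumptions for illustration; the pad list follows the `torch.nn.functional.pad` convention (last dimension first, as left/right pairs), which this sketch assumes the padding op shares.

```python
def padding_needed(length, alignment):
    # Number of extra elements required to reach the next multiple of alignment.
    return (alignment - length % alignment) % alignment


def pad_list_2d(rows, cols, alignment):
    # One pad list covering both dimensions, so a single padding op suffices:
    # [cols_left, cols_right, rows_left, rows_right], padding only at the end.
    return [0, padding_needed(cols, alignment), 0, padding_needed(rows, alignment)]


# Hypothetical GEMM operand of shape (35, 120) with an assumed alignment of 8.
pads = pad_list_2d(35, 120, alignment=8)
padded_shape = (35 + pads[3], 120 + pads[1])
```

A dimension that is already aligned gets a padding of zero, which is the case the transposed-GEMM option tries to exploit.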
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118522
Approved by: https://github.com/jansel, https://github.com/eellison
dmypy silently ignores follow_imports = skip, so to get parity between
dmypy and mypy we have to suck it up and type: ignore all of the sympy
typing problems.
The suppressions were added automatically with the following script generated by GPT-4:
```
import re

# Read the error file
with open("error_file.txt", "r") as f:
    errors = f.readlines()

# Parse the lines with errors and error types
error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

# Insert ignore comments in the source files
for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    # Walk the lines bottom-up; each edit touches exactly one line in place
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118469
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432, #118467, #118468
The original motivation for MYPYINDUCTOR was a faster type checking configuration that only checked a subset of files. With the removal of `follow_imports = ignore`, we are now able to use dmypy to do fast incremental typechecking, eliminating the need for this.
Perhaps erroneously, when I tee'ed up this PR I elected to delete the `follow_imports = skip` designations in the mypy-inductor.ini. This led to a number of extra type error suppressions that I manually edited. You will need to review.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118432
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418
This diff introduces the following changes:
1. Fix sympy_subs to preserve the integer and non-negative properties of the replaced symbol when the replacement is a string.
Why is this needed?
I was compiling the expression
x*abs(y) where y = -2.
This expression is passed as ``s1*abs(s0)``, then s0 is replaced with ks0 by a call to sympy_subs.
But sympy_subs used to replace s0 (integer=False, nonnegative=False) with ks0 (integer=True, nonnegative=True),
resulting in ``x*abs(ks0) = x*ks0``, which is wrong.
2. Rename sympy_symbol to sympy_index_symbol to make it explicit.
3. Add an assertion that the replaced expression is not passed as a string but always as a sympy expression.
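The failure in item 1 can be reproduced with plain sympy, independent of Inductor. This is only a sketch of the symptom: the variable names are made up, and copying `assumptions0` is one way to carry assumptions over, not necessarily what the fix does.

```python
import sympy

s0, s1 = sympy.symbols("s0 s1")  # no integer / nonnegative assumptions
expr = s1 * sympy.Abs(s0)

# Replacement symbol that silently gains assumptions the original never had.
ks0_assumed = sympy.Symbol("ks0", integer=True, nonnegative=True)
wrong = expr.subs({s0: ks0_assumed})  # Abs(ks0) simplifies away to ks0

# Replacement symbol that copies the assumptions of the symbol it replaces.
ks0_preserved = sympy.Symbol("ks0", **s0.assumptions0)
right = expr.subs({s0: ks0_preserved})  # Abs(ks0) is kept

# Evaluating at y = -2 shows the sign error in the first variant.
wrong_at_minus_two = wrong.subs({ks0_assumed: -2})
right_at_minus_two = right.subs({ks0_preserved: -2})
```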
Fixes https://github.com/pytorch/pytorch/issues/117757
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118150
Approved by: https://github.com/ezyang
For a persistent reduction, we generate 2 flavors of 'equivalent' kernels at the same time:
- persistent reduction
- regular reduction
A MultiKernel wraps these 2 kernels and picks the one with better performance at runtime.
Here I talk more about implementation details:
- Inductor maintains state for generating kernels, e.g. the wrapper code. After we generate code for one kernel, we need to restore the inductor state before we can generate the counterpart.
***There is one thing I need some comments from others***:
There is one tricky thing about kernel arguments. In general, inductor removes a buffer from the argument list if it's only used inside the kernel. But somehow a buffer removed by the persistent reduction kernel may still be kept by the regular (non-persistent) reduction kernel because of some CSE invalidation rule. My current implementation avoids removing buffers if multi_kernel is enabled. This makes sure both flavors of reduction have a consistent argument list. Another idea I have is to generate the multi-kernel definition with the union of arguments from both sub-kernels and let each sub-kernel pick the subset of arguments it wants. But this will make the code-gen for multi-kernel much more complex.
I'm not sure if there is some easy and clean way to resolve this.
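To illustrate the runtime selection, here is a toy sketch in plain Python. `MultiKernelSketch` and the two stand-in functions are hypothetical and unrelated to the generated Triton code; both stand-ins deliberately take the same argument list, which is the constraint discussed above.

```python
import time


class MultiKernelSketch:
    """Wraps equivalent callables and keeps whichever benchmarks fastest."""

    def __init__(self, kernels):
        self.kernels = list(kernels)
        self.picked = None

    @staticmethod
    def _benchmark(kernel, args, repeats=5):
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            kernel(*args)
            best = min(best, time.perf_counter() - start)
        return best

    def __call__(self, *args):
        if self.picked is None:
            # Benchmark once, on the first real call, then reuse the choice.
            timings = [self._benchmark(k, args) for k in self.kernels]
            self.picked = self.kernels[timings.index(min(timings))]
        return self.picked(*args)


def persistent_sum(xs):
    # Stand-in for the persistent reduction flavor.
    return sum(xs)


def looped_sum(xs):
    # Stand-in for the regular (looped) reduction flavor.
    total = 0
    for x in xs:
        total += x
    return total


multi_kernel = MultiKernelSketch([persistent_sum, looped_sum])
total = multi_kernel(list(range(1000)))
```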
Testing command:
```
TORCHINDUCTOR_MULTI_KERNEL=1 TORCH_LOGS=+torch._inductor.graph TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 python benchmarks/dynamo/huggingface.py --backend inductor --amp --performance --only BertForMaskedLM --training
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103469
Approved by: https://github.com/jansel
As the title says, this PR enables vectorization for the situation when the index_expr depends on the vectorized itervar. There are two cases here:
1. The vectorized itervar has constant stride in the index_expr. We vectorize the index_expr with `Vectorized<int32>::arange` for this case.
2. Otherwise, we load the index_expr vector in a non-contiguous way with a loop.
Below is the generated code for the first case from the test `test_concat_inner_vec`. Here `x1` is the index_expr and depends on the vectorized itervar `x1`. It has constant stride 1. We vectorized it with arange. We use `all_zero` to implement a short-cut for masks to avoid unnecessary execution of nested masked regions which are invalid.
Before:
```c++
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(155L); x1+=static_cast<long>(1L))
{
auto tmp0 = c10::convert<long>(x1);
auto tmp1 = static_cast<long>(0);
auto tmp2 = tmp0 >= tmp1;
auto tmp3 = static_cast<long>(35);
auto tmp4 = tmp0 < tmp3;
auto tmp5 = [&]
{
auto tmp6 = in_ptr0[static_cast<long>(x1 + (35L*x0))];
return tmp6;
}
;
auto tmp7 = tmp4 ? tmp5() : static_cast<decltype(tmp5())>(0.0);
auto tmp8 = tmp0 >= tmp3;
auto tmp9 = static_cast<long>(155);
auto tmp10 = tmp0 < tmp9;
auto tmp11 = [&]
{
auto tmp12 = in_ptr1[static_cast<long>((-35L) + x1 + (120L*x0))];
return tmp12;
}
;
...
```
After:
```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(144L); x1+=static_cast<long>(16L))
{
auto tmp0 = c10::convert<int>(x1);
auto tmp1 = at::vec::Vectorized<int32_t>::arange(tmp0, 1);
auto tmp2 = static_cast<int>(0);
auto tmp3 = at::vec::Vectorized<int>(tmp2);
auto tmp4 = to_float_mask(tmp1 >= tmp3);
auto tmp5 = static_cast<int>(35);
auto tmp6 = at::vec::Vectorized<int>(tmp5);
auto tmp7 = to_float_mask(tmp1 < tmp6);
auto tmp8 = [&]
{
auto tmp9 = masked_load(in_ptr0 + static_cast<long>(x1 + (35L*x0)), to_float_mask(tmp7));
return tmp9;
}
;
auto tmp10 =
[&]
{
if (all_zero(to_float_mask(tmp7)))
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(tmp8())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp8(), to_float_mask(tmp7));
}
}
()
;
...
```
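As a reading aid for the first case, a scalar Python emulation of the vector semantics. The helpers `arange_lanes` and `masked_load` are simplified stand-ins written for this note, not the ATen implementations, and the vector width of 16 is taken from the loop step above.

```python
VEC_WIDTH = 16  # lanes per vector, matching the x1 step in the generated loop


def arange_lanes(base, step, width=VEC_WIDTH):
    # Emulates an arange-style index vector: base, base + step, ...
    return [base + step * lane for lane in range(width)]


def masked_load(buf, start, mask, other=0.0):
    # Shortcut in the spirit of all_zero: skip the load if no lane is active.
    if not any(mask):
        return [other] * len(mask)
    # Inactive lanes never index the buffer.
    return [buf[start + lane] if keep else other for lane, keep in enumerate(mask)]


in_ptr0 = [float(i) for i in range(35)]  # first concat input, 35 elements wide

idx = arange_lanes(32, 1)                # x1 = 32: lanes 32..47
mask = [i < 35 for i in idx]             # only lanes 32, 33, 34 are in range
partial_vec = masked_load(in_ptr0, 32, mask)

idx_far = arange_lanes(48, 1)            # x1 = 48: entirely past the first input
mask_far = [i < 35 for i in idx_far]
empty_vec = masked_load(in_ptr0, 48, mask_far)
```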
Below is the generated code for the second case from the test case `test_expr_vec_non_contiguous`. Here, the index_expr is `31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))` which depends on the vectorized itervar `x2` and doesn't have constant stride. So, we load the index_expr vector with a loop. (In fact, this can be further optimized since the index_expr is invariant with the data points in the range [x2, x2+16). So it can be regarded as a scalar. This will be optimized in the follow-up PR.) The code uses `vector_lane_mask_check` to implement the masked version of non-contiguous load.
Before:
```c++
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
{
auto tmp0 = c10::convert<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L)));
auto tmp1 = static_cast<long>(2048);
auto tmp2 = tmp0 < tmp1;
auto tmp3 = [&]
{
auto tmp4 = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (2048L*(static_cast<long>(x1) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
return tmp4;
}
;
auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
tmp_acc0 = max_propagate_nan(tmp_acc0, tmp5);
}
out_ptr0[static_cast<long>(x1 + (1024L*x0))] = tmp_acc0;
}
}
}
```
After:
```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
{
{
#pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
{
auto tmp0 =
[&]
{
__at_align__ std::array<int, 16> tmpbuf;
#pragma GCC unroll 16
for (long x1_inner = 0; x1_inner < 16; x1_inner++)
{
tmpbuf[x1_inner] = static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (c10::div_floor_integer(x2, 32L)));
}
return at::vec::Vectorized<int>::loadu(tmpbuf.data());
}
()
;
auto tmp1 = static_cast<int>(2048);
auto tmp2 = at::vec::Vectorized<int>(tmp1);
auto tmp3 = to_float_mask(tmp0 < tmp2);
auto tmp4 = [&]
{
auto tmp5 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x1_inner = 0; x1_inner < 16; x1_inner++)
{
if (vector_lane_mask_check(tmp3, x1_inner))
{
tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data());
}
()
;
return tmp5;
}
;
auto tmp6 =
[&]
{
if (all_zero(to_float_mask(tmp3)))
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(tmp4())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp4(), to_float_mask(tmp3));
}
}
()
;
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp6);
}
tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0)));
}
}
}
}
```
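The invariance remark for the second case can be checked with a few lines of Python. This sketch assumes the 16-lane window starts at a multiple of 16; `index_expr` and `is_lane_invariant` are helper names invented for the check.

```python
def index_expr(x1, x2):
    # Same expression as in the test case, with Python floor division.
    return 31 + 63 * (x1 // 32) + (x2 // 32)


def is_lane_invariant(x1, x2_start, width=16):
    # True if all lanes in [x2_start, x2_start + width) see the same value.
    values = {index_expr(x1, x2_start + lane) for lane in range(width)}
    return len(values) == 1


# Windows aligned to 16 never straddle a multiple of 32, so x2 // 32 is constant.
aligned_ok = all(is_lane_invariant(0, start) for start in range(0, 1024, 16))

# An unaligned window such as [24, 40) crosses 32 and is not invariant.
unaligned_ok = is_lane_invariant(0, 24)
```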
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114545
Approved by: https://github.com/lezcano
As the [RFC](https://github.com/pytorch/pytorch/issues/114856) mentions, this is step 1 of adding the Intel GPU backend as an alternative inductor backend.
### Design
Typically, in order to integrate Intel GPU backend into Inductor, we need to inherit from `WrapperCodegen` and `TritonScheduling` and implement the corresponding subclasses respectively. However, since `WrapperCodegen` and `TritonScheduling` have some device-bias code generation **scattered** in their methods, overriding them in subclasses would introduce a lot of duplicated parent class code.
For example:
2a44034895/torch/_inductor/codegen/wrapper.py (L487)
2a44034895/torch/_inductor/codegen/triton.py (L1996)
So we abstract the device-bias code scattered in WrapperCodegen and TritonScheduling and provide a unified interface, "DeviceOpOverrides". This way, when integrating a new backend, we can maximize the reuse of `WrapperCodegen` and `TritonScheduling` code by inheriting from and implementing this interface for device flexibility.
Currently the `DeviceOpOverrides` only covers Python wrapper code generation. We can further extend it to cover Cpp wrapper code generation on demand.
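A stripped-down sketch of the pattern, for illustration only. The method set, the registry helpers and the emitted strings here are simplified assumptions and should not be read as the actual interface added by this PR.

```python
class DeviceOpOverrides:
    """Each method returns a snippet of Python wrapper code for one device."""

    def set_device(self, device_idx):
        raise NotImplementedError

    def synchronize(self):
        raise NotImplementedError


class CUDADeviceOpOverrides(DeviceOpOverrides):
    def set_device(self, device_idx):
        return f"torch.cuda.set_device({device_idx})"

    def synchronize(self):
        return "torch.cuda.synchronize()"


class XPUDeviceOpOverrides(DeviceOpOverrides):
    # Illustrative strings for a second backend.
    def set_device(self, device_idx):
        return f"torch.xpu.set_device({device_idx})"

    def synchronize(self):
        return "torch.xpu.synchronize()"


_registry = {}  # hypothetical registry: device type -> overrides instance


def register_device_op_overrides(device, overrides):
    _registry[device] = overrides


def get_device_op_overrides(device):
    return _registry[device]


def codegen_sync(device):
    # Shared codegen stays device-agnostic; only the overrides differ.
    return get_device_op_overrides(device).synchronize()


register_device_op_overrides("cuda", CUDADeviceOpOverrides())
register_device_op_overrides("xpu", XPUDeviceOpOverrides())
```

With this shape, the shared wrapper code asks the registry for a snippet instead of hard-coding a device, so a new backend only registers its own overrides.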
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116020
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
Fixes #114310 and supersedes #114748.
There are two reasons why we have quite a few special cases for `round`:
1. `round` is actually two ops. With `ndigits=None` (default), `round` always returns an integer. When `ndigits` is an integer, the returned type is a float.
2. Although `round` takes two arguments, it is a unary function with a parameter rather than a binary one.
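Both points can be seen with Python's builtin `round` on floats, which is the behavior the description above matches. The helper `make_round` is made up here to show the "unary function with a parameter" view.

```python
# Point 1: the return type depends on whether ndigits is given.
as_int = round(2.7)         # ndigits omitted: result is an int
as_float = round(2.567, 2)  # ndigits given: result is a float
still_float = round(2.0, 0) # even ndigits=0 yields a float


# Point 2: ndigits acts as a parameter of a unary function, not a second operand.
def make_round(ndigits):
    def rounder(x):
        return round(x, ndigits)
    return rounder


round_to_cents = make_round(2)
price = round_to_cents(19.999)
```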
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115259
Approved by: https://github.com/peterbell10, https://github.com/lezcano
This adds the `ir.Scan` node (currently only supported on CUDA) which re-uses the existing reduction kernel machinery to support different kinds of non-pointwise ops. Just like reductions it supports prologue and epilogue fusions and has both persistent and non-persistent kernel generation.
Currently this doesn't support the equivalent of `Reduction.create_multilayer` and will instead fall back to eager in those cases. This is because splitting into multiple kernel invocations ends up being far slower than cub's single kernel strategy which matches the performance of a copy kernel.
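For readers less familiar with the term, a scan keeps every running result of an associative combine function, whereas a reduction keeps only the last one. A minimal pure-Python sketch (the function name `scan` is just for this example):

```python
import operator
from functools import reduce
from itertools import accumulate


def scan(values, combine_fn):
    # Inclusive scan: output[i] combines values[0..i].
    return list(accumulate(values, combine_fn))


data = [1, 2, 3, 4]
cumsum = scan(data, operator.add)
cumprod = scan(data, operator.mul)

# The final element of a scan equals the corresponding reduction.
total = reduce(operator.add, data)
```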
Fixes https://github.com/pytorch/pytorch/issues/93631
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106581
Approved by: https://github.com/lezcano, https://github.com/atalman
torch.split(x, l) fails when l is a list of unbacked symints.
E.g. l = y.tolist() makes the entries of l unbacked, because l depends on the
data of y. The downstream call `SliceView.create()`
evaluates the shape even if the input shape is an unbacked symint,
which triggers the bug.
Test Plan:
python test/inductor/test_unbacked_symints.py -k test_split_with_sizes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113406
Approved by: https://github.com/aakhundov, https://github.com/ezyang