## Summary
This is re-land PR for https://github.com/pytorch/pytorch/pull/100706 to address the compilation latency performance regression.
## Root Cause
Regarding the C++/OpenMP backend, `codecache.pick_vec_isa()` to check vectorization ISA is a time-consuming and one-shot operation. It leads to taking a longer time to import `codegen.cpp` package because the `LoopLevel` of the package is decorated by `@dataclasses.dataclass` while the decorator will invoke `codecache.pick_vec_isa()` to initialize the `simd_nelements` of the `LoopLevel`.
c14cf312c9/torch/_inductor/codegen/cpp.py (L2883C53-L2883C53)
In terms of the Triton backend, it does not need to touch it. But we'd prefer to uniform the code. Therefore, the new design simultaneously registers `CpuScheduling` for CPU and `TritonScheduling` for Triton regardless of whether the current backend is Triton. It will bring additional overhead to the Triton backend.
```python
def init_backend_registration(self):
if get_scheduling_for_device("cpu") is None:
from .codegen.cpp import CppScheduling
register_backend_for_device("cpu", CppScheduling, WrapperCodeGen)
if get_scheduling_for_device("cuda") is None:
from .codegen.triton import TritonScheduling
register_backend_for_device("cuda", TritonScheduling, WrapperCodeGen)
```
## Solution
To resolve the compilation latency regression for the Triton backend, we changed the `LoopLevel` a little bit([new code changes](https://github.com/pytorch/pytorch/pull/106874/files#diff-5ab7b0235e2076a5fc6629ba0b109208940f5b94f5c13babc3e0f87cf4fcec82R2893-R2904)) by moving the `simd_nelements` to `__post_init__` and the compilation performance would be back.
## Compilation Latency Performance Result
We ran a single model benchmark and reproduced the compilation regression:
- Run `python benchmarks/dynamo/torchbench.py -dcuda --training --performance --inductor --only hf_Bart`
- W/ PR #100706, the compilation latency is about **57~58**
```
dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio,eager_peak_mem,dynamo_peak_mem,calls_captured,unique_graphs,graph_breaks,unique_graph_breaks
cuda,hf_Bart,4,1.556712,109.676554,57.055242,0.936330,5.760698,6.152422,642,1,8,7
cuda,hf_Bart,4,1.646658,109.621747,57.909817,0.936330,5.760698,6.152422,642,1,8,7
```
- W/O PR #100706, the compilation latency is about **46~47**
```
dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio,eager_peak_mem,dynamo_peak_mem,calls_captured,unique_graphs,graph_breaks,unique_graph_breaks
cuda,hf_Bart,4,1.599065,108.702480,47.490346,0.936330,5.760698,6.152422,642,1,8,7
cuda,hf_Bart,4,1.588419,108.431411,46.983041,0.936330,5.760698,6.152422,642,1,8,7
```
This PR fixed the compilation performance regression.
- W/ this PR #106874, the compilation latency is about **47~48**
```
dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio,eager_peak_mem,dynamo_peak_mem,calls_captured,unique_graphs,graph_breaks,unique_graph_breaks
cuda,hf_Bart,4,1.586261,108.149467,47.481058,0.936330,5.760698,6.152422,642,1,8,7
cuda,hf_Bart,4,1.758915,108.613899,47.925633,0.936330,5.760698,6.152422,642,1,8,7
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106874
Approved by: https://github.com/jansel
I found that for a tiled kernel for tensor with shape [a, b], we map 'a' with XBLOCK and 'b' with YBLOCK. However, 'a' actually should be the outer looper while 'b' corresponding to the inner loop. This order is picked by our loop ordering algorithm. Mapping 'a' with XBLOCK has the semantic like assigning 'a' to the inner loop instead.
For a simple 'A + B.t()' kernel, making the loop order consistent can brings 1.027x speedup ( 1.938ms -> 1.887ms speedup) . Here are the dump of kernels:
- before fix: https://gist.github.com/shunting314/4dacf73cf495cdd7e84dede7c3e0872d
- after fix (this one is done manually): https://gist.github.com/shunting314/441e8839d24e1878c313e539b1ebd551
I tried this on DistillGPT2 and found perf is neutral. But that because DistillGPT2 has a single tiled pointwise kernel in it's backward graph. Will check the dashboard.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106827
Approved by: https://github.com/jansel
`JITFunction._key_of` uses the value of the argument to distinguish between
i32 and i64, but this fails if the value is used in indexing calculations where
the value exceeds `INT_MAX`.
Instead, we should use `index_dtype` which means all indexing calculations are
performed in the same dtype.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106870
Approved by: https://github.com/lezcano
ghstack dependencies: #106626
This PR intends to extend Inductor to support the third-party backend that only focuses on the code generation just like what C++/OpenMP and Triton backend have done.
Currently, the generated code by Inductor contains two major parts. One is the kernel, and the other is the Python wrapper to glue the kernel. Therefore, the third-party backend needs to customize the two parts to generate its specific code.
- Python wrapper code generation
Inductor provides a `WrapperCodeGen` class to generate the Python wrapper code to glue the kernel. Therefore, it is straightforward for the third-party backend to generate the backend-specific Python wrapper code. It just needs to inherit the `WrapperCodeGen` class and purposely override the particular member functions.
- Kernel code generation
It is driven by different `Scheduling`. Hence, the third-party backend needs to provide a custom `Scheduling` for its specific kernel code generation. Currently, `CppScheduling` and `TritonScheduling` are for C++/OpenMP and Triton backend, respectively. But there is no common `Scheduling` class. Based on the scheduling invocation, this PR abstracts a common `Scheduling` class containing the following member functions.
- [group_fn](71c4becda7/torch/_inductor/scheduler.py (LL649C64-L649C64))
- [flush](71c4becda7/torch/_inductor/scheduler.py (L1150))
- [can_fuse_vertical](71c4becda7/torch/_inductor/scheduler.py (L1006))
- [can_fuse_horizontal](71c4becda7/torch/_inductor/scheduler.py (LL1008C45-L1008C64))
- [codegen_template](71c4becda7/torch/_inductor/scheduler.py (L1234)) _This function is only available for triton. If the third-party backend behaves as a sub-class of `TritonScheduling`, it can override it or reuse it._
- [codegen_nodes](71c4becda7/torch/_inductor/scheduler.py (L1234))
- [codegen_sync](71c4becda7/torch/_inductor/scheduler.py (LL1251C1-L1251C1)). _This function is only available for triton debug purpose. But it might also be useful for other computation devices. Therefore, we'd prefer to keep this function._
The third-party backend needs to inherit from the `Scheduling` class and implement these functions.
Regarding some other classes like `CppKernel` and `TritonKernel` for code generation, they are used by or part of the logic of either `Scheduling` or `WrapperCodeGen`. Hence, this PR does not define the interface and leaves the flexibility to the third-party backend. The third-party backend can decide to implement these classes from scratch or reuse them by inheriting and overriding them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100706
Approved by: https://github.com/jansel
Previously, when fusing a single node into a foreach op, the scheduler would iterate over each subnode and check if it can be fused, this PR adds a mapping so that the node to be fused with can be found more quickly by checking dependencies.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106008
Approved by: https://github.com/jansel
dependencies.py is used for tracking reads and writes, which is used for identifying dependencies between buffers: i.e. if buffer X reads buffer Y, then X depends on Y. ops.bucketize() reads from an offsets tensor, so we should track it in dependencies.py to correctly track dependencies. Since bucketize performs a binary search over the offsets tensor, the dependency is marked as a StarDep to indicate that the entire tensor is needed.
Use case: we find that jagged tensor dense_to_jagged ops - which use bucketize() to map jagged indices to dense indices - perform better if the bucketize() kernel is separated from the gather kernel. Previously, because bucketize() wasn't marked as reading anything, it would just get inlined.
Differential Revision: [D47422704](https://our.internmc.facebook.com/intern/diff/D47422704)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105102
Approved by: https://github.com/eellison
When running BertForMaskedLM , I found if I enable the kernel benchmark, essentially identical kernels will be defined once for each call site. The reason is the benchmark harness of those kernels uses different seed_offset for each invocation. We should be safe to just force seed_offset to be 0 so we can deduplicate identical kernel definitions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105099
Approved by: https://github.com/jansel
dependencies.py is used for tracking reads and writes, which is used for identifying dependencies between buffers: i.e. if buffer X reads buffer Y, then X depends on Y. ops.bucketize() reads from an offsets tensor, so we should track it in dependencies.py to correctly track dependencies. Since bucketize performs a binary search over the offsets tensor, the dependency is marked as a StarDep to indicate that the entire tensor is needed.
Use case: we find that jagged tensor dense_to_jagged ops - which use bucketize() to map jagged indices to dense indices - perform better if the bucketize() kernel is separated from the gather kernel. Previously, because bucketize() wasn't marked as reading anything, it would just get inlined.
Differential Revision: [D47422704](https://our.internmc.facebook.com/intern/diff/D47422704)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105102
Approved by: https://github.com/eellison
This allows `ops.minimum` and `ops.maximum` to be hoisted for indirect indexing
into direct indexing expressions. I also add support to the cpp printer for
Min/Max and fix the triton printer to support multi-argument Min/Max.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105020
Approved by: https://github.com/lezcano
Fixes#101684
Before this change, we get a float constant in triton
```
tmp0 = 0.2
```
which in triton IR becomes a float32 value
```
%cst_0 = arith.constant dense<2.000000e-01> : tensor<2xf32>
```
After, we get a tensor with explicit type
```
tmp0 = tl.full([1], 0.2, tl.float64)
```
which does generate a float64 in the triton IR
```
%cst_0 = arith.constant dense<2.000000e-01> : tensor<2xf64>
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104830
Approved by: https://github.com/lezcano
This is intended as a first step towards reductions with multiple outputs. This
also incidentally improves CSE of reductions under C++ codegen. For example,
```python
def fn(x):
return torch.argmin(x, dim=-1), torch.argmin(x, dim=-1)
```
Currently this generates two reductions, where the common load is CSEd
```cpp
for(long i1=static_cast<long>(0L); i1<static_cast<long>(10); i1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(i1 + (10L*i0))];
if (tmp_acc0.value > tmp0) {
tmp_acc0.index = i1; tmp_acc0.value = tmp0;
}
if (tmp_acc1.value > tmp0) {
tmp_acc1.index = i1; tmp_acc1.value = tmp0;
}
}
auto tmp1 = tmp_acc0.index;
out_ptr0[static_cast<long>(i0)] = tmp1;
auto tmp2 = tmp_acc1.index;
out_ptr1[static_cast<long>(i0)] = tmp2;
```
but with this change it gets CSEd to a single accumulator
```cpp
for(long i1=static_cast<long>(0L); i1<static_cast<long>(10L); i1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(i1 + (10L*i0))];
if (tmp_acc0.value > tmp0) {
tmp_acc0.index = i1; tmp_acc0.value = tmp0;
}
}
auto tmp1 = tmp_acc0.index;
out_ptr0[static_cast<long>(i0)] = tmp1;
out_ptr1[static_cast<long>(i0)] = tmp1;
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102737
Approved by: https://github.com/jgong5, https://github.com/lezcano
This is a bit inefficient because it computes the mean and throws it
away since ir.Reduction nodes only have 1 output. However, the mean
can at least be scheduled into the same loop as the variance now since
there is no data dependency. Thus we can take fewer passes over the
data.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102486
Approved by: https://github.com/lezcano, https://github.com/jansel
Background/problem: ops.bucketize needs to take a value `offsets_size`, which is the length of the `offsets` tensor. It is used, e.g., for the bounds of the binary search over the `offsets` tensor. The previous implementation of `ops.bucketize` expected `offsets_size` to be a CSEVariable; i.e. we'd pass `offsets_size = ops.index_expr(offsets.get_size()[0])` into `ops.bucketize()`. However, `ops.index_expr` will sometimes broadcast, turning the scalar `offsets_size` into a tensor. That caused errors, because [triton_helpers.bucketize_binary_search](a2fe6953bc/torch/_inductor/triton_helpers.py (L153-L155)) expects `offsets_size` to be a scalar. [Link - where the broadcasting happens](a2fe6953bc/torch/_inductor/codegen/triton.py (L1056))
Solution (this PR): Instead of passing `offsets_size` into `ops.bucketize` as a CSEVariable, pass in a sympy.Expr. Then, inside ops.bucketize, convert the sympy.Expr into a string that can be used in the generated triton code.
Differential Revision: [D47282413](https://our.internmc.facebook.com/intern/diff/D47282413)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104756
Approved by: https://github.com/jansel
In binary search triton implementations, (#104007) num_elements_per_warp=32 performs a lot better than larger values.
This PR adds an autotuning config option for this purpose. But since autotuning can affect compile times and this config isn't generally useful, we only try this config if bucketize is present. This is done by adding an extra field to triton_meta which is used by the pointwise autotuning
Performance: reused https://gist.github.com/davidberard98/066fd2115f59f5889ef61e4527d1eba5.
Before:
```
Eager 0.30088499188423157 ms
PT2 0.9296960234642029 ms
```
After:
```
Eager 0.3011910021305084 ms
PT2 0.22977299988269806 ms
```
Differential Revision: [D47237103](https://our.internmc.facebook.com/intern/diff/D47237103)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104456
Approved by: https://github.com/eellison
We recently have an optimization to squash x dimension for persistent reduction kernel when we are confident that XBLOCK will always be 1. We need update the code so that coordinate descent tuner does not tune XBLOCK in this case.
Test command. Fail before the fix and pass after.
```
TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 python benchmarks/dynamo/huggingface.py --backend inductor --amp --accuracy --only BertForMaskedLM --inference
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104692
Approved by: https://github.com/jansel
This is intended as a first step towards reductions with multiple outputs. This
also incidentally improves CSE of reductions under C++ codegen. For example,
```python
def fn(x):
return torch.argmin(x, dim=-1), torch.argmin(x, dim=-1)
```
Currently this generates two reductions, where the common load is CSEd
```cpp
for(long i1=static_cast<long>(0L); i1<static_cast<long>(10); i1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(i1 + (10L*i0))];
if (tmp_acc0.value > tmp0) {
tmp_acc0.index = i1; tmp_acc0.value = tmp0;
}
if (tmp_acc1.value > tmp0) {
tmp_acc1.index = i1; tmp_acc1.value = tmp0;
}
}
auto tmp1 = tmp_acc0.index;
out_ptr0[static_cast<long>(i0)] = tmp1;
auto tmp2 = tmp_acc1.index;
out_ptr1[static_cast<long>(i0)] = tmp2;
```
but with this change it gets CSEd to a single accumulator
```cpp
for(long i1=static_cast<long>(0L); i1<static_cast<long>(10L); i1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(i1 + (10L*i0))];
if (tmp_acc0.value > tmp0) {
tmp_acc0.index = i1; tmp_acc0.value = tmp0;
}
}
auto tmp1 = tmp_acc0.index;
out_ptr0[static_cast<long>(i0)] = tmp1;
out_ptr1[static_cast<long>(i0)] = tmp1;
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102737
Approved by: https://github.com/jgong5, https://github.com/lezcano
This is a bit inefficient because it computes the mean and throws it
away since ir.Reduction nodes only have 1 output. However, the mean
can at least be scheduled into the same loop as the variance now since
there is no data dependency. Thus we can take fewer passes over the
data.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102486
Approved by: https://github.com/lezcano, https://github.com/jansel
**TL;DR**: This PR is a first step in adding lowerings for torch.bucketize. It adds an initial lowering for this op - but because this implementation is not currently efficient, it registers the lowering for prims._inductor_bucketize. After we make the implementation more efficient, we'll remove prims._inductor_bucketize and add the lowering directly to torch.bucketize.
**Background - torch.bucketize**: torch.bucketize(values, boundaries, right=False): for an arbitrary tensor of values and a non-decreasing 1D tensor of boundaries that define buckets, it returns the index of the bucket that each of the values will fall in. e.g. for values [0, 1, 2, 3, 4] and boundaries [1, 3], it will return [0, 0, 1, 1, 2].
**Implementation**: This PR adds a new inductor op called "bucketize". In this PR it only has a triton implementation - for CPU it is a fallback. The triton implementation uses a binary search in `triton_helpers.py`. This PR also adds a new prim `_inductor_bucketize()` for testing purposes and adds lowering for this op.
~~**"right"**: The current behavior of the "right" kwarg in the inductor op is the opposite of the behavior of the torch op. "right" controls how the op treats a value that is equal to one of the boundary values. In the torch op, "right=True" means "if a value is equal to a boundary value, then put it in the bucket to the right". In the inductor op, "right=True" means "the right boundary of a bucket is closed". These are opposite. **I'm open to switching the behavior of the inductor op** - but I chose to implement this way because I think it makes more sense, and I think the torch.bucketize behavior may have been a mistake (it's the opposite of numpy.digitize).~~ Switched the behavior of the inductor bucketize op to match the torch op
* places where "right" means "if a value is equal to a boundary value, then put it in the bucket to the right" (i.e. current torch.bucketize behavior)
+ current torch.bucketize behavior
+ table in [torch.bucketize docs](https://pytorch.org/docs/stable/generated/torch.bucketize.html)
* places where "right" means "the right boundary of a bucket is closed":
+ the text description of [torch.bucketize docs](https://pytorch.org/docs/stable/generated/torch.bucketize.html) (observed in #91580)
+ [numpy.digitize](https://numpy.org/doc/stable/reference/generated/numpy.digitize.html) (which is basically the same op)
**Performance**: Benchmark script: "values" as a [16, 1024, 1024] float32 tensor and "boundaries" as a [1025] tensor (i.e. defining 1024 buckets).
As is:
```
Eager 0.30117499828338623 ms
PT2 0.9298200011253357 ms
```
But performance improves significantly if we add an additional pointwise autotuning config (WIP in #104456):
```
Eager 0.3015420138835907 ms
PT2 0.23028500378131866 ms
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104007
Approved by: https://github.com/jansel
I doubt there's much difference in performance, but this improves readability of
the generated code, e.g.
```python
tmp8 = triton_helpers.max2(tmp7, 1)[:, None]
```
becomes
```python
tmp8 = triton_helpers.any(tmp7, 1)[:, None]
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103974
Approved by: https://github.com/lezcano
Fixes#103481
Normally triton tensors have shape `[XBLOCK, RBLOCK]`, or some variation where
the lengths are 1 but the number of dimensions is the same. The `no_x_dim`
change in addition to removing the x dimension, also removed the r dimension
from certain values such as the results of reductions and the `xindex` variable.
This fixes those two cases to correctly produce tensors of shape `[1]`,
equivalent to the old shape `[XBLOCK, 1]` with the x-dimension dropped.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103527
Approved by: https://github.com/ngimel
ValueRanges can't handle symbolic bounds. Be a bit more careful about detecting if you try to pass in expressions with free symbols, and fall back to "don't know" range if this occurs.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103470
Approved by: https://github.com/eellison
Originally, my goal for this PR was to remove the `dynamic_shapes` tests in torch/_dynamo/variables/builder.py. However, one thing lead to another, and it turns out that it was easiest to do all of the following in one go:
* Unconditionally allocate a ShapeEnv, no matter if dynamic_shapes is enabled or not (torch/_dynamo/output_graph.py). There is a small adjustment to export torch/_dynamo/eval_frame.py to account for the fact that a ShapeEnv always exists, even if you're not doing symbolic export.
* Remove dynamic_shapes test from unspec logic (torch/_dynamo/variables/builder.py), the original goal
* Specialize strides and storage offset if all sizes are dynamic (torch/fx/experimental/symbolic_shapes.py). This is required to deal with unconditional ShapeEnv: if a ShapeEnv exist, fake tensor-ification may choose to allocate symbols. The idea is that with `automatic_dynamic_shapes == False`, Dynamo should never request dynamic sizes, but this invariant was not upheld for nontrivial strides/offset.
The rest are just auxiliary fixups from the above:
* Workaround bug in FakeTensorProp where sometimes it doesn't return a FakeTensor (torch/fx/passes/fake_tensor_prop.py), see https://github.com/pytorch/pytorch/pull/103395 for follow up
* Make ShapeProp correctly handle int inputs (torch/fx/passes/shape_prop.py)
* Disable indexing strength reduction if `assume_static_by_default` is False (torch/_inductor/codegen/triton.py)
* Fix hf_T5_generate to NOT toggle `assume_static_by_default` if dynamic shapes is not enabled (benchmarks/dynamo/common.py); technically this is not necessary anymore but it's in for safety.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103302
Approved by: https://github.com/voznesenskym
This helps with kernels that make use of caching like mid-range softmax
which reads the data three times.
Selecting `eviction_policy=evict_first` in the last loop of the softmax
operation seems to give a 7-10% speed-up vs. selecting `evict_last` which
was the previous option. I'll put up some benchmarks soon™.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91316
Approved by: https://github.com/ngimel, https://github.com/jansel
Currently reduction bodies are duplicated in several different places.
This reduces duplication by `combine_fn` definition used in
`_unroll_reduction_fn` and using it in the triton codegen. For cpp
this also makes better use of `reduction_combine{,_vec}` by using them
to generate the `omp declare reduction` line and the `vec_reduce_all`
call.
For triton the only change is that that the combine step gets spread
over two lines, e.g. instead of:
```python
_tmp1 = tl.where(rmask & xmask, triton_helpers.maximum(_tmp1, tmp0), _tmp1)
```
we get
```python
tmp2 = triton_helpers.maximum(_tmp1, tmp0)
_tmp1 = tl.where(rmask & xmask, tmp2, _tmp1)
```
For cpp the only change is that inplace reduction operations are now written as
an out-of-place operation and an assignment, e.g. instead if
```cpp
omp_out += omp_in
```
we generate
```cpp
omp_out = omp_out + omp_in
```
Which is a purely cosmetic change
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99661
Approved by: https://github.com/lezcano, https://github.com/ngimel