As the title says, this PR enables vectorization for the case where the index_expr depends on the vectorized itervar. There are two cases here:
1. The vectorized itervar has constant stride in the index_expr. We vectorize the index_expr with `Vectorized<int32>::arange` for this case.
2. Otherwise, we load the index_expr vector in a non-contiguous way with a loop.
Below is the generated code for the first case from the test `test_concat_inner_vec`. Here the index_expr is `x1`, which is itself the vectorized itervar, so it has constant stride 1. We vectorize it with arange. We use `all_zero` to implement a short-cut for masks to avoid unnecessary execution of nested masked regions which are invalid.
Before:
```c++
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(155L); x1+=static_cast<long>(1L))
{
auto tmp0 = c10::convert<long>(x1);
auto tmp1 = static_cast<long>(0);
auto tmp2 = tmp0 >= tmp1;
auto tmp3 = static_cast<long>(35);
auto tmp4 = tmp0 < tmp3;
auto tmp5 = [&]
{
auto tmp6 = in_ptr0[static_cast<long>(x1 + (35L*x0))];
return tmp6;
}
;
auto tmp7 = tmp4 ? tmp5() : static_cast<decltype(tmp5())>(0.0);
auto tmp8 = tmp0 >= tmp3;
auto tmp9 = static_cast<long>(155);
auto tmp10 = tmp0 < tmp9;
auto tmp11 = [&]
{
auto tmp12 = in_ptr1[static_cast<long>((-35L) + x1 + (120L*x0))];
return tmp12;
}
;
...
```
After:
```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(32L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(144L); x1+=static_cast<long>(16L))
{
auto tmp0 = c10::convert<int>(x1);
auto tmp1 = at::vec::Vectorized<int32_t>::arange(tmp0, 1);
auto tmp2 = static_cast<int>(0);
auto tmp3 = at::vec::Vectorized<int>(tmp2);
auto tmp4 = to_float_mask(tmp1 >= tmp3);
auto tmp5 = static_cast<int>(35);
auto tmp6 = at::vec::Vectorized<int>(tmp5);
auto tmp7 = to_float_mask(tmp1 < tmp6);
auto tmp8 = [&]
{
auto tmp9 = masked_load(in_ptr0 + static_cast<long>(x1 + (35L*x0)), to_float_mask(tmp7));
return tmp9;
}
;
auto tmp10 =
[&]
{
if (all_zero(to_float_mask(tmp7)))
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(tmp8())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp8(), to_float_mask(tmp7));
}
}
()
;
...
```
Below is the generated code for the second case from the test case `test_expr_vec_non_contiguous`. Here, the index_expr is `31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L))` which depends on the vectorized itervar `x2` and doesn't have constant stride. So, we load the index_expr vector with a loop. (In fact, this can be further optimized since the index_expr is invariant with the data points in the range [x2, x2+16). So it can be regarded as a scalar. This will be optimized in the follow-up PR.) The code uses `vector_lane_mask_check` to implement the masked version of non-contiguous load.
Before:
```c++
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
{
{
float tmp_acc0 = -std::numeric_limits<float>::infinity();
for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
{
auto tmp0 = c10::convert<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (c10::div_floor_integer(x2, 32L)));
auto tmp1 = static_cast<long>(2048);
auto tmp2 = tmp0 < tmp1;
auto tmp3 = [&]
{
auto tmp4 = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer(x1, 32L))) + (2048L*(static_cast<long>(x1) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
return tmp4;
}
;
auto tmp5 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
tmp_acc0 = max_propagate_nan(tmp_acc0, tmp5);
}
out_ptr0[static_cast<long>(x1 + (1024L*x0))] = tmp_acc0;
}
}
}
```
After:
```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
{
{
#pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
float tmp_acc0 = -std::numeric_limits<float>::infinity();
at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
{
auto tmp0 =
[&]
{
__at_align__ std::array<int, 16> tmpbuf;
#pragma GCC unroll 16
for (long x1_inner = 0; x1_inner < 16; x1_inner++)
{
tmpbuf[x1_inner] = static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (c10::div_floor_integer(x2, 32L)));
}
return at::vec::Vectorized<int>::loadu(tmpbuf.data());
}
()
;
auto tmp1 = static_cast<int>(2048);
auto tmp2 = at::vec::Vectorized<int>(tmp1);
auto tmp3 = to_float_mask(tmp0 < tmp2);
auto tmp4 = [&]
{
auto tmp5 =
[&]
{
__at_align__ std::array<float, 16> tmpbuf;
#pragma GCC unroll 16
for (long x1_inner = 0; x1_inner < 16; x1_inner++)
{
if (vector_lane_mask_check(tmp3, x1_inner))
{
tmpbuf[x1_inner] = in_ptr0[static_cast<long>(31L + (63L*(c10::div_floor_integer((x1 + x1_inner), 32L))) + (2048L*(static_cast<long>((x1 + x1_inner)) % static_cast<long>(32L))) + (65536L*x0) + (c10::div_floor_integer(x2, 32L)))];
}
}
return at::vec::Vectorized<float>::loadu(tmpbuf.data());
}
()
;
return tmp5;
}
;
auto tmp6 =
[&]
{
if (all_zero(to_float_mask(tmp3)))
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(tmp4())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp4(), to_float_mask(tmp3));
}
}
()
;
tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp6);
}
tmp_acc0_vec.store(out_ptr0 + static_cast<long>(x1 + (1024L*x0)));
}
}
}
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114545
Approved by: https://github.com/lezcano
As the [RFC](https://github.com/pytorch/pytorch/issues/114856) mentions, this is step 1 of adding the Intel GPU backend as an alternative Inductor backend.
### Design
Typically, in order to integrate Intel GPU backend into Inductor, we need to inherit from `WrapperCodegen` and `TritonScheduling` and implement the corresponding subclasses respectively. However, since `WrapperCodegen` and `TritonScheduling` have some device-bias code generation **scattered** in their methods, overriding them in subclasses would introduce a lot of duplicated parent class code.
For example:
2a44034895/torch/_inductor/codegen/wrapper.py (L487)2a44034895/torch/_inductor/codegen/triton.py (L1996)
So we abstract the device-bias code scattered in WrapperCodegen and TritonScheduling and provide a unified interface "DeviceOpOverrides". This way, when integrating a new backend, we can maximize the reuse of `WrapperCodegen` and `TritonScheduling` code by inheriting from and implementing this interface for device flexibility.
Currently `DeviceOpOverrides` only covers Python wrapper code generation. We can further extend it to cover Cpp wrapper code generation on demand.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116020
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
Fixes #114310 and supersedes #114748.
There are two reasons why we have quite a few special cases for `round`:
1. `round` is actually two ops. With `ndigits=None` (default), `round` always returns an integer. When `ndigits` is an integer, the returned type is a float.
2. Although `round` takes two arguments, it is a unary function with a parameter rather than a binary one.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115259
Approved by: https://github.com/peterbell10, https://github.com/lezcano
This adds the `ir.Scan` node (currently only supported on CUDA) which re-uses the existing reduction kernel machinery to support different kinds of non-pointwise ops. Just like reductions it supports prologue and epilogue fusions and has both persistent and non-persistent kernel generation.
Currently this doesn't support the equivalent of `Reduction.create_multilayer` and will instead fall back to eager in those cases. This is because splitting into multiple kernel invocations ends up being far slower than cub's single kernel strategy which matches the performance of a copy kernel.
Fixes https://github.com/pytorch/pytorch/issues/93631
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106581
Approved by: https://github.com/lezcano, https://github.com/atalman
torch.split(x, l) fails when l's shape is an unbacked symint.
E.g. l =
y.tolist() makes l an unbacked shape, because l depends on the
data access of y. The downstream call `SliceView.create()`
evaluates the shape even if the input shape is an unbacked symint,
which triggers the bug.
Test Plan:
python test/inductor/test_unbacked_symints.py -k test_split_with_sizes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113406
Approved by: https://github.com/aakhundov, https://github.com/ezyang
torch.split(x, l) fails when l's shape is an unbacked symint.
E.g. l =
y.tolist() makes l an unbacked shape, because l depends on the
data access of y. The downstream call `SliceView.create()`
evaluates the shape even if the input shape is an unbacked symint,
which triggers the bug.
Test Plan:
python test/inductor/test_unbacked_symints.py -k test_split_with_sizes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113406
Approved by: https://github.com/aakhundov, https://github.com/ezyang
SymIntType is referenced by wrapper.py, so I added its .pyi definition.
I also added SymBoolType along the way for completeness.
The `isinstance` checks in wrapper.py reference torch.Type, which seems
to cause mypy to choke. Not entirely sure why; I've just added
type-ignore comments for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113411
Approved by: https://github.com/Skylion007
ghstack dependencies: #113409, #113410
This was originally @jansel's PR:
https://github.com/pytorch/pytorch/pull/102625, which I've built upon.
This diff implements static memory planning. It's disabled by default
while we examine its performance.
We use a greedy-by-size approach. For dynamic shapes, the sizes of the
example inputs are used as estimates when making planning decisions. We
generate expressions to calculate the actual memory offsets and sizes at
runtime when the values of the dynamic shapes are known. In order to
simplify these calculations, we have organized the allocations into a
tree that branches on space (address offsets) and time (live ranges).
Finally, we need to align these offsets, so we have added an `align`
sympy Expr to express these calculations.
Some limitations:
1. It is only enabled during inference for now. Enabling it for training
increases peak memory usage as we allocate all the memory needed for
training upfront, before freeing the memory allocated during
inference. We can probably address this by doing planning for both
the inference and training passes together.
2. It doesn't work with PyTorch Distributed, because kernels like
AllGatherIntoTensor codegen strings which do memory operations. We
can fix this down the line by having them emit MemoryPlanningLines
instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112178
Approved by: https://github.com/desertfire, https://github.com/jansel
This was originally @jansel's PR:
https://github.com/pytorch/pytorch/pull/102625, which I've built upon.
This diff implements static memory planning. It's disabled by default
while we examine its performance.
We use a greedy-by-size approach. For dynamic shapes, the sizes of the
example inputs are used as estimates when making planning decisions. We
generate expressions to calculate the actual memory offsets and sizes at
runtime when the values of the dynamic shapes are known. In order to
simplify these calculations, we have organized the allocations into a
tree that branches on space (address offsets) and time (live ranges).
Finally, we need to align these offsets, so we have added an `align`
sympy Expr to express these calculations.
Some limitations:
1. It is only enabled during inference for now. Enabling it for training
increases peak memory usage as we allocate all the memory needed for
training upfront, before freeing the memory allocated during
inference. We can probably address this by doing planning for both
the inference and training passes together.
2. It doesn't work with PyTorch Distributed, because kernels like
AllGatherIntoTensor codegen strings which do memory operations. We
can fix this down the line by having them emit MemoryPlanningLines
instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112178
Approved by: https://github.com/desertfire, https://github.com/jansel
**Summary**
Follow-up to https://github.com/pytorch/pytorch/pull/109893, which has an issue with CPU support as reported in https://github.com/pytorch/pytorch/issues/109897. This fix mainly includes 2 changes:
- The current implementation of `rename_indexing`
10c646295d/torch/_inductor/codegen/common.py (L1023) only adds symbols whose names start with `s` or `ps` to `kernel.args.sizevars`. However, the name of an `Unbacked symint` starts with `i`, so we extend the implementation of `rename_indexing` to support symbols whose names start with `i`.
- Currently, the names of the internal loop indices also start with `i`. Since `i` is already used for `Unbacked symint`, change these names to start with `x`, which aligns with Triton.
**Test Plan**
```
python -u -m pytest -s -v test_torchinductor_dynamic_shapes.py -k test_bool_mask_nobreak
python -u -m pytest -s -v test_torchinductor_dynamic_shapes.py -k test_nonzero_size_factory_nobreak
python -u -m pytest -s -v test_torchinductor_dynamic_shapes.py -k test_item_zeros_nobreak
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110262
Approved by: https://github.com/ezyang, https://github.com/jgong5
In https://github.com/pytorch/pytorch/pull/107901, the CUDA event based
profiling is changed to profiler based profiling to avoid counting CPU-side
kernel launch overhead in final latency numbers. However, it turns out that
torch.profile() is significantly slower than CUDA events, which affects model
compilation speed quite significantly. This PR changes back to CUDA event
based profiling.
Follow-ups:
* Try CUDA event profiling with CUDAGraphs;
* Multi-GPU profiling;
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109338
Approved by: https://github.com/frank-wei
This adds the `ir.Scan` node (currently only supported on CUDA) which re-uses the existing reduction kernel machinery to support different kinds of non-pointwise ops. Just like reductions it supports prologue and epilogue fusions and has both persistent and non-persistent kernel generation.
Currently this doesn't support the equivalent of `Reduction.create_multilayer` and will instead fall back to eager in those cases. This is because splitting into multiple kernel invocations ends up being far slower than cub's single kernel strategy which matches the performance of a copy kernel.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106581
Approved by: https://github.com/lezcano, https://github.com/atalman
Inductor kernel codegen previously had the following side effects:
- in `Kernel.__exit__`, we add locally used buffers to graph.removed_buffers
- during codegen, we do memory allocation/free.
These make it hard to codegen multiple versions of the same kernel. This PR refactors the code so that kernel codegen does not change graph-level state. After codegening a kernel, the graph-level state is unchanged, so we can go on to codegen another version of the kernel if we want.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107617
Approved by: https://github.com/jansel
We'd like to benchmark fusion (either for autotuning or for gathering data to find some patterns that can guide optimizations). There is a deadlock here that prevents us from doing this: to benchmark fusion, we need to do codegen before all the fusions are done. However, codegen currently relies on xSchedulerNode.last_usage information to decide which buffers are not needed at all and thus don't even need to be allocated/written (Scheduler.removed_buffers tracks this). xSchedulerNode.last_usage information can only be computed once the order of all the nodes has been decided. But each fusion pass (`fuse_nodes_once`) can also change node orders. So we know the final node orders only after all the fusions have completed. That blocks us from doing codegen during fusion (before all fusions are done).
Here I just show the above with a chain of dependencies to make it easier to understand (a -> b means a depends on b, or b has to happen before a):
```
benchmark one fusion decision -> codegen -> xSchedulerNode.last_usage -> node order -> all fusions have completed
```
Actually we only need to decide if a buffer has only local usages (if yes, it's a candidate for removing). This can be decided if we know what are all the users for each buffer. We can avoid using xSchedulerNode.last_usage in this case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107320
Approved by: https://github.com/peterbell10, https://github.com/jansel
This replaces `var_unnormalized` reduction type with `welford_reduce` which takes the input data and outputs not just the variance, but also the mean and weights which account for the full welford accumulator state. Thus we can avoid re-computing the mean, and we now have enough information to create a multilayer reduction which I implement here by adding a second reduction type called `welford_combine` which reduces over all three inputs simultaneously.
Multi-layer support is particularly important as normalization operators like BatchNorm are being split in many timm models, which meant `var_unnormalized` had to fall back to two-pass variance calculation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104725
Approved by: https://github.com/lezcano
When removing an inplace buffer, we just mark it as ```REMOVED```. After removing some inplace buffers, if we then mark another buffer as an inplace buffer and use the length of ```self.inplace_buffers.values()``` to create its name, we may define an inplace buffer name that already exists in ```self.inplace_buffers.values()```:
Before removing some inplace buffers, ```self.inplace_buffers``` may look like:
```
{'buf0': InplacedBuffer(inner_name='in_out_ptr0', other_names=['buf0', 'buf2', 'buf4']), 'buf2': InplacedBuffer(inner_name='in_out_ptr0', other_names=['buf0', 'buf2', 'buf4']), 'buf4': InplacedBuffer(inner_name='in_out_ptr0', other_names=['buf0', 'buf2', 'buf4']), 'buf5': InplacedBuffer(inner_name='in_out_ptr1', other_names=['buf5', 'buf7', 'buf9']), 'buf7': InplacedBuffer(inner_name='in_out_ptr1', other_names=['buf5', 'buf7', 'buf9']), 'buf9': InplacedBuffer(inner_name='in_out_ptr1', other_names=['buf5', 'buf7', 'buf9']), 'buf12': InplacedBuffer(inner_name='in_out_ptr2', other_names=['buf12', 'buf13']), 'buf13': InplacedBuffer(inner_name='in_out_ptr2', other_names=['buf12', 'buf13']), 'buf17': InplacedBuffer(inner_name='in_out_ptr3', other_names=['buf17', 'buf19']), 'buf19': InplacedBuffer(inner_name='in_out_ptr3', other_names=['buf17', 'buf19']), 'buf21': InplacedBuffer(inner_name='in_out_ptr4', other_names=['buf21', 'buf25']), 'buf25': InplacedBuffer(inner_name='in_out_ptr4', other_names=['buf21', 'buf25']), 'buf20': InplacedBuffer(inner_name='in_out_ptr5', other_names=['buf20', 'buf26', 'buf31', 'buf32']), 'buf26': InplacedBuffer(inner_name='in_out_ptr5', other_names=['buf20', 'buf26', 'buf31', 'buf32']), 'buf31': InplacedBuffer(inner_name='in_out_ptr5', other_names=['buf20', 'buf26', 'buf31', 'buf32']), 'buf32': InplacedBuffer(inner_name='in_out_ptr5', other_names=['buf20', 'buf26', 'buf31', 'buf32'])}
```
After removing some inplace buffers, ```self.inplace_buffers``` may look like:
```
{'buf0': InplacedBuffer(inner_name='in_out_ptr0', other_names=['buf0', 'buf2', 'buf4']), 'buf2': InplacedBuffer(inner_name='in_out_ptr0', other_names=['buf0', 'buf2', 'buf4']), 'buf4': InplacedBuffer(inner_name='in_out_ptr0', other_names=['buf0', 'buf2', 'buf4']), 'buf5': 'REMOVED', 'buf7': 'REMOVED', 'buf9': 'REMOVED', 'buf12': 'REMOVED', 'buf13': 'REMOVED', 'buf17': InplacedBuffer(inner_name='in_out_ptr3', other_names=['buf17', 'buf19']), 'buf19': InplacedBuffer(inner_name='in_out_ptr3', other_names=['buf17', 'buf19']), 'buf21': 'REMOVED', 'buf25': 'REMOVED', 'buf20': 'REMOVED', 'buf26': 'REMOVED', 'buf31': 'REMOVED', 'buf32': 'REMOVED', 'buf16': InplacedBuffer(inner_name='in_out_ptr6', other_names=['buf16', 'buf38']), 'buf38': InplacedBuffer(inner_name='in_out_ptr6', other_names=['buf16', 'buf38'])}
```
If we then mark some buffer as an inplace buffer, its name is built as ```in_out_ptr{len(unique(self.inplace_buffers.values()))}```, so the name may be ```in_out_ptr6``` even though this name already exists in ```self.inplace_buffers```.
After this PR, we change ```REMOVED``` to ```REMOVED{1, 2, 3..}```, which avoids defining a duplicate name. ```pyhpc_equation_of_state``` of ```torchbench``` will then work for the CPU backend:
```python -m torch.backends.xeon.run_cpu --node_id 0 benchmarks/dynamo/torchbench.py --performance --inference --float32 -dcpu -n50 --inductor --freezing --no-skip --dashboard --only pyhpc_equation_of_state --cold_start_latency```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106852
Approved by: https://github.com/lezcano