This replaces the `var_unnormalized` reduction type with `welford_reduce`, which takes the input data and outputs not just the variance but also the mean and weights, i.e. the full Welford accumulator state. Thus we avoid re-computing the mean, and we now have enough information to build a multi-layer reduction, which I implement here by adding a second reduction type called `welford_combine` that reduces over all three inputs simultaneously.
Multi-layer support is particularly important because normalization operators like BatchNorm are split in many timm models, which previously forced `var_unnormalized` to fall back to a two-pass variance calculation.
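For reference, the Welford accumulator state and the two reduction layers correspond to the following plain-Python sketch (the actual implementation is emitted in the generated Triton/C++ code; names here are illustrative):
```python
# Plain-Python reference for the Welford state and the two reduction layers.
from dataclasses import dataclass

@dataclass
class WelfordState:
    mean: float = 0.0
    m2: float = 0.0      # sum of squared deviations; variance = m2 / weight
    weight: float = 0.0  # number of elements accumulated so far

def welford_reduce(state: WelfordState, x: float) -> WelfordState:
    # first layer: accumulate one input element
    weight = state.weight + 1.0
    delta = x - state.mean
    mean = state.mean + delta / weight
    m2 = state.m2 + delta * (x - mean)
    return WelfordState(mean, m2, weight)

def welford_combine(a: WelfordState, b: WelfordState) -> WelfordState:
    # second layer: merge two partial accumulators (mean, m2, weight)
    weight = a.weight + b.weight
    if weight == 0.0:
        return WelfordState()
    delta = b.mean - a.mean
    mean = a.mean + delta * (b.weight / weight)
    m2 = a.m2 + b.m2 + delta * delta * (a.weight * b.weight / weight)
    return WelfordState(mean, m2, weight)
```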
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104725
Approved by: https://github.com/lezcano
When removing an inplace buffer, we just mark it as `REMOVED`. If we later mark another buffer as an inplace buffer and derive its name from the number of unique entries in `self.inplace_buffers.values()`, we may end up defining a name that already exists in `self.inplace_buffers.values()`:
Before removing any inplace buffers, `self.inplace_buffers` may look like:
```
{'buf0': InplacedBuffer(inner_name='in_out_ptr0', other_names=['buf0', 'buf2', 'buf4']), 'buf2': InplacedBuffer(inner_name='in_out_ptr0', other_names=['buf0', 'buf2', 'buf4']), 'buf4': InplacedBuffer(inner_name='in_out_ptr0', other_names=['buf0', 'buf2', 'buf4']), 'buf5': InplacedBuffer(inner_name='in_out_ptr1', other_names=['buf5', 'buf7', 'buf9']), 'buf7': InplacedBuffer(inner_name='in_out_ptr1', other_names=['buf5', 'buf7', 'buf9']), 'buf9': InplacedBuffer(inner_name='in_out_ptr1', other_names=['buf5', 'buf7', 'buf9']), 'buf12': InplacedBuffer(inner_name='in_out_ptr2', other_names=['buf12', 'buf13']), 'buf13': InplacedBuffer(inner_name='in_out_ptr2', other_names=['buf12', 'buf13']), 'buf17': InplacedBuffer(inner_name='in_out_ptr3', other_names=['buf17', 'buf19']), 'buf19': InplacedBuffer(inner_name='in_out_ptr3', other_names=['buf17', 'buf19']), 'buf21': InplacedBuffer(inner_name='in_out_ptr4', other_names=['buf21', 'buf25']), 'buf25': InplacedBuffer(inner_name='in_out_ptr4', other_names=['buf21', 'buf25']), 'buf20': InplacedBuffer(inner_name='in_out_ptr5', other_names=['buf20', 'buf26', 'buf31', 'buf32']), 'buf26': InplacedBuffer(inner_name='in_out_ptr5', other_names=['buf20', 'buf26', 'buf31', 'buf32']), 'buf31': InplacedBuffer(inner_name='in_out_ptr5', other_names=['buf20', 'buf26', 'buf31', 'buf32']), 'buf32': InplacedBuffer(inner_name='in_out_ptr5', other_names=['buf20', 'buf26', 'buf31', 'buf32'])}
```
After removing some inplace buffers, `self.inplace_buffers` may look like:
```
{'buf0': InplacedBuffer(inner_name='in_out_ptr0', other_names=['buf0', 'buf2', 'buf4']), 'buf2': InplacedBuffer(inner_name='in_out_ptr0', other_names=['buf0', 'buf2', 'buf4']), 'buf4': InplacedBuffer(inner_name='in_out_ptr0', other_names=['buf0', 'buf2', 'buf4']), 'buf5': 'REMOVED', 'buf7': 'REMOVED', 'buf9': 'REMOVED', 'buf12': 'REMOVED', 'buf13': 'REMOVED', 'buf17': InplacedBuffer(inner_name='in_out_ptr3', other_names=['buf17', 'buf19']), 'buf19': InplacedBuffer(inner_name='in_out_ptr3', other_names=['buf17', 'buf19']), 'buf21': 'REMOVED', 'buf25': 'REMOVED', 'buf20': 'REMOVED', 'buf26': 'REMOVED', 'buf31': 'REMOVED', 'buf32': 'REMOVED', 'buf16': InplacedBuffer(inner_name='in_out_ptr6', other_names=['buf16', 'buf38']), 'buf38': InplacedBuffer(inner_name='in_out_ptr6', other_names=['buf16', 'buf38'])}
```
If we then mark another buffer as an inplace buffer and name it `in_out_ptr{len(unique(self.inplace_buffers.values()))}`, the new name may be `in_out_ptr6`, even though that name already exists in `self.inplace_buffers`.
After this PR, we change `REMOVED` to `REMOVED{1, 2, 3, ...}`, which avoids defining a duplicate name. With this fix, `pyhpc_equation_of_state` from `torchbench` works on the CPU backend:
```python -m torch.backends.xeon.run_cpu --node_id 0 benchmarks/dynamo/torchbench.py --performance --inference --float32 -dcpu -n50 --inductor --freezing --no-skip --dashboard --only pyhpc_equation_of_state --cold_start_latency```
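To illustrate, here is a small self-contained sketch (illustrative values, not the actual inductor data structures) of why counting unique values collides and how numbered removal markers avoid it:
```python
# Old scheme: all removals collapse to a single "REMOVED" value, so the count
# of unique values can reuse a name that was already handed out.
buffers = {
    "buf0": "in_out_ptr0", "buf2": "in_out_ptr0",
    "buf5": "REMOVED", "buf7": "REMOVED",   # were in_out_ptr1
    "buf12": "REMOVED",                     # was in_out_ptr2
}
assert len(set(buffers.values())) == 2     # next name would be in_out_ptr2 again

# New scheme: number each removal marker so the unique-value count keeps
# growing and freshly created names never repeat.
buffers.update({"buf5": "REMOVED1", "buf7": "REMOVED1", "buf12": "REMOVED2"})
assert len(set(buffers.values())) == 3     # next name: in_out_ptr3
```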
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106852
Approved by: https://github.com/lezcano
This PR aims to sort out the data type for `constant`.
The constant should be promoted to float (https://github.com/pytorch/pytorch/pull/105440), so there are several changes to make:
- Data type propagation should propagate a constant node to `float` dtype if its original dtype is `bfloat16` (a small sketch of this rule follows the list).
- We do not need to insert `to_dtype` after the `constant` node; initializing an `fp32` constant directly is faster:
```
vectorized<bfloat16> tmp(value);
vectorized<float> tmp1 = cvt_bf16_fp32(tmp);
->
vectorized<float> tmp(value);
```
- Move `constant` out of the list of operations that can support bf16 without converting to fp32.
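For illustration, the promotion rule boils down to something like the following (hypothetical helper name; the real logic lives in inductor's data type propagation pass):
```python
import torch

def promote_constant_dtype(dtype: torch.dtype) -> torch.dtype:
    # bf16 constants are materialized directly as fp32 vectors, so propagate
    # them as float32 instead of emitting a separate to_dtype conversion.
    return torch.float32 if dtype == torch.bfloat16 else dtype

assert promote_constant_dtype(torch.bfloat16) == torch.float32
assert promote_constant_dtype(torch.float16) == torch.float16
```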
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105827
Approved by: https://github.com/jgong5, https://github.com/jansel
This PR extends Inductor to support third-party backends that only need to focus on code generation, just like the existing C++/OpenMP and Triton backends do.
Currently, the code generated by Inductor has two major parts: the kernels and the Python wrapper that glues them together. A third-party backend therefore needs to customize both parts to generate its own code.
- Python wrapper code generation
Inductor provides a `WrapperCodeGen` class that generates the Python wrapper code gluing the kernels together. Generating backend-specific wrapper code is therefore straightforward: the third-party backend just inherits from `WrapperCodeGen` and overrides the relevant member functions.
- Kernel code generation
Kernel code generation is driven by a `Scheduling` object, so the third-party backend needs to provide a custom `Scheduling` for its kernel code generation. Currently `CppScheduling` and `TritonScheduling` serve the C++/OpenMP and Triton backends, respectively, but there is no common `Scheduling` base class. Based on how scheduling is invoked, this PR abstracts a common `Scheduling` class containing the following member functions:
- [group_fn](71c4becda7/torch/_inductor/scheduler.py (LL649C64-L649C64))
- [flush](71c4becda7/torch/_inductor/scheduler.py (L1150))
- [can_fuse_vertical](71c4becda7/torch/_inductor/scheduler.py (L1006))
- [can_fuse_horizontal](71c4becda7/torch/_inductor/scheduler.py (LL1008C45-L1008C64))
- [codegen_template](71c4becda7/torch/_inductor/scheduler.py (L1234)) _This function is currently only used by Triton. If the third-party backend behaves as a subclass of `TritonScheduling`, it can override or reuse it._
- [codegen_nodes](71c4becda7/torch/_inductor/scheduler.py (L1234))
- [codegen_sync](71c4becda7/torch/_inductor/scheduler.py (LL1251C1-L1251C1)). _This function is currently only used for Triton debugging, but it might also be useful for other compute devices, so we prefer to keep it._
The third-party backend needs to inherit from the `Scheduling` class and implement these functions.
Other code-generation classes such as `CppKernel` and `TritonKernel` are used by, or are part of, the logic of either `Scheduling` or `WrapperCodeGen`. This PR therefore does not define an interface for them and leaves that flexibility to the third-party backend, which can implement such classes from scratch or reuse them through inheritance and overriding.
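As a rough sketch of what such a backend looks like (the import path is an assumption for this era of the codebase, the method bodies are elided, and the registration hook is omitted):
```python
# Rough sketch only: a backend skeleton following the interface described above.
from torch._inductor.codegen.wrapper import WrapperCodeGen

class MyWrapperCodeGen(WrapperCodeGen):
    """Backend-specific Python wrapper codegen: override only what differs,
    e.g. headers, memory allocation, and the kernel-call convention."""

class MyScheduling:
    """Implements the common Scheduling interface abstracted by this PR."""
    def __init__(self, scheduler):
        self.scheduler = scheduler
    def group_fn(self, sizes): ...               # how iteration ranges are grouped
    def can_fuse_vertical(self, node1, node2): ...
    def can_fuse_horizontal(self, node1, node2): ...
    def codegen_nodes(self, nodes): ...          # emit a kernel for fused scheduler nodes
    def codegen_sync(self): ...                  # device synchronization, if needed
    def flush(self): ...                         # finalize pending kernels
```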
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100706
Approved by: https://github.com/jansel
This is intended as a first step towards reductions with multiple outputs. This
also incidentally improves CSE of reductions under C++ codegen. For example,
```python
def fn(x):
    return torch.argmin(x, dim=-1), torch.argmin(x, dim=-1)
```
Currently this generates two reductions, where the common load is CSEd
```cpp
for(long i1=static_cast<long>(0L); i1<static_cast<long>(10); i1+=static_cast<long>(1L))
{
    auto tmp0 = in_ptr0[static_cast<long>(i1 + (10L*i0))];
    if (tmp_acc0.value > tmp0) {
        tmp_acc0.index = i1; tmp_acc0.value = tmp0;
    }
    if (tmp_acc1.value > tmp0) {
        tmp_acc1.index = i1; tmp_acc1.value = tmp0;
    }
}
auto tmp1 = tmp_acc0.index;
out_ptr0[static_cast<long>(i0)] = tmp1;
auto tmp2 = tmp_acc1.index;
out_ptr1[static_cast<long>(i0)] = tmp2;
```
but with this change it gets CSEd to a single accumulator
```cpp
for(long i1=static_cast<long>(0L); i1<static_cast<long>(10L); i1+=static_cast<long>(1L))
{
    auto tmp0 = in_ptr0[static_cast<long>(i1 + (10L*i0))];
    if (tmp_acc0.value > tmp0) {
        tmp_acc0.index = i1; tmp_acc0.value = tmp0;
    }
}
auto tmp1 = tmp_acc0.index;
out_ptr0[static_cast<long>(i0)] = tmp1;
out_ptr1[static_cast<long>(i0)] = tmp1;
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102737
Approved by: https://github.com/jgong5, https://github.com/lezcano
Background/problem: ops.bucketize needs to take a value `offsets_size`, which is the length of the `offsets` tensor. It is used, e.g., for the bounds of the binary search over the `offsets` tensor. The previous implementation of `ops.bucketize` expected `offsets_size` to be a CSEVariable; i.e. we'd pass `offsets_size = ops.index_expr(offsets.get_size()[0])` into `ops.bucketize()`. However, `ops.index_expr` will sometimes broadcast, turning the scalar `offsets_size` into a tensor. That caused errors, because [triton_helpers.bucketize_binary_search](a2fe6953bc/torch/_inductor/triton_helpers.py (L153-L155)) expects `offsets_size` to be a scalar. [Link - where the broadcasting happens](a2fe6953bc/torch/_inductor/codegen/triton.py (L1056))
Solution (this PR): Instead of passing `offsets_size` into `ops.bucketize` as a CSEVariable, pass in a sympy.Expr. Then, inside ops.bucketize, convert the sympy.Expr into a string that can be used in the generated triton code.
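A simplified sketch of the gist (not the exact inductor call signatures):
```python
# The length of the offsets tensor stays a sympy.Expr all the way into codegen.
import sympy

offsets_size: sympy.Expr = sympy.Symbol("ks0")  # e.g. offsets.get_size()[0]

# Inside the bucketize codegen the expression is rendered to a string that is
# spliced into the Triton call, so the binary-search helper receives a scalar
# rather than a broadcast tensor. (Inductor uses its own index-expr printer.)
offsets_size_str = str(offsets_size)  # -> "ks0"
```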
Differential Revision: [D47282413](https://our.internmc.facebook.com/intern/diff/D47282413)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104756
Approved by: https://github.com/jansel
**TL;DR**: This PR is a first step in adding lowerings for torch.bucketize. It adds an initial lowering for this op - but because this implementation is not currently efficient, it registers the lowering for prims._inductor_bucketize. After we make the implementation more efficient, we'll remove prims._inductor_bucketize and add the lowering directly to torch.bucketize.
**Background - torch.bucketize**: torch.bucketize(values, boundaries, right=False): for an arbitrary tensor of values and a non-decreasing 1D tensor of boundaries that define buckets, it returns the index of the bucket that each of the values will fall in. e.g. for values [0, 1, 2, 3, 4] and boundaries [1, 3], it will return [0, 0, 1, 1, 2].
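The example above can be checked directly against eager mode:
```python
import torch

values = torch.tensor([0, 1, 2, 3, 4])
boundaries = torch.tensor([1, 3])  # buckets (right=False): (-inf, 1], (1, 3], (3, inf)
print(torch.bucketize(values, boundaries))              # tensor([0, 0, 1, 1, 2])
print(torch.bucketize(values, boundaries, right=True))  # tensor([0, 1, 1, 2, 2])
```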
**Implementation**: This PR adds a new inductor op called "bucketize". In this PR it only has a triton implementation - for CPU it is a fallback. The triton implementation uses a binary search in `triton_helpers.py`. This PR also adds a new prim `_inductor_bucketize()` for testing purposes and adds lowering for this op.
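A plain-Python equivalent of that binary search (illustrative only; the real helper in `triton_helpers.py` operates on Triton tensors and blocks):
```python
def bucketize_scalar(value, boundaries, right=False):
    # right=False: first index with boundaries[i] >= value (left insertion)
    # right=True:  first index with boundaries[i] >  value (right insertion)
    lo, hi = 0, len(boundaries)  # hi plays the role of offsets_size
    while lo < hi:
        mid = (lo + hi) // 2
        if (boundaries[mid] < value) or (right and boundaries[mid] == value):
            lo = mid + 1
        else:
            hi = mid
    return lo

assert [bucketize_scalar(v, [1, 3]) for v in [0, 1, 2, 3, 4]] == [0, 0, 1, 1, 2]
```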
~~**"right"**: The current behavior of the "right" kwarg in the inductor op is the opposite of the behavior of the torch op. "right" controls how the op treats a value that is equal to one of the boundary values. In the torch op, "right=True" means "if a value is equal to a boundary value, then put it in the bucket to the right". In the inductor op, "right=True" means "the right boundary of a bucket is closed". These are opposite. **I'm open to switching the behavior of the inductor op** - but I chose to implement this way because I think it makes more sense, and I think the torch.bucketize behavior may have been a mistake (it's the opposite of numpy.digitize).~~ Switched the behavior of the inductor bucketize op to match the torch op
* places where "right" means "if a value is equal to a boundary value, then put it in the bucket to the right" (i.e. current torch.bucketize behavior)
+ current torch.bucketize behavior
+ table in [torch.bucketize docs](https://pytorch.org/docs/stable/generated/torch.bucketize.html)
* places where "right" means "the right boundary of a bucket is closed":
+ the text description of [torch.bucketize docs](https://pytorch.org/docs/stable/generated/torch.bucketize.html) (observed in #91580)
+ [numpy.digitize](https://numpy.org/doc/stable/reference/generated/numpy.digitize.html) (which is basically the same op)
**Performance**: Benchmark script: "values" as a [16, 1024, 1024] float32 tensor and "boundaries" as a [1025] tensor (i.e. defining 1024 buckets).
As is:
```
Eager 0.30117499828338623 ms
PT2 0.9298200011253357 ms
```
But performance improves significantly if we add an additional pointwise autotuning config (WIP in #104456):
```
Eager 0.3015420138835907 ms
PT2 0.23028500378131866 ms
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104007
Approved by: https://github.com/jansel
**Summary**
Refactor the vectorized code generation for the uint8 input data type. Previously, we combined the uint8 data load and the uint8-to-float conversion into a single step via `load_uint8_as_float` and `store_float_as_uint8`. After this refactor, we split them into two steps, load/store and data type conversion, to match the behavior of the BFloat16 data type.
The previously generated code is:
```
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(432L); i0+=static_cast<long>(16L))
{
    auto tmp0 = at::vec::load_uint8_as_float(in_ptr0 + static_cast<long>(i0));
    auto tmp1 = (tmp0);
    auto tmp2 = at::vec::Vectorized<float>(static_cast<float>(100.0));
    auto tmp3 = tmp1 - tmp2;
    auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(0.01));
    auto tmp5 = tmp3 * tmp4;
    auto tmp6 = at::vec::clamp_min(tmp5, decltype(tmp5)(0));
    auto tmp7 = tmp6 * tmp2;
    auto tmp8 = tmp7.round();
    auto tmp9 = tmp8 + tmp2;
    auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(0.0));
    auto tmp11 = at::vec::maximum(tmp9, tmp10);
    auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(255.0));
    auto tmp13 = at::vec::minimum(tmp11, tmp12);
    auto tmp14 = (tmp13);
    at::vec::store_float_as_uint8(tmp14, out_ptr0 + static_cast<long>(i0));
}
```
After this PR, the generated code is:
```
#pragma omp for
for(long i0=static_cast<long>(0L); i0<static_cast<long>(432L); i0+=static_cast<long>(16L))
{
    auto tmp0 = at::vec::Vectorized<uint8_t>::loadu(in_ptr0 + static_cast<long>(i0), 16);
    auto tmp1 = cvt_uint8_to_fp32_with_same_elem_num(tmp0);
    auto tmp2 = at::vec::Vectorized<float>(static_cast<float>(100.0));
    auto tmp3 = tmp1 - tmp2;
    auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(0.01));
    auto tmp5 = tmp3 * tmp4;
    auto tmp6 = at::vec::clamp_min(tmp5, decltype(tmp5)(0));
    auto tmp7 = tmp6 * tmp2;
    auto tmp8 = tmp7.round();
    auto tmp9 = tmp8 + tmp2;
    auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(0.0));
    auto tmp11 = at::vec::maximum(tmp9, tmp10);
    auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(255.0));
    auto tmp13 = at::vec::minimum(tmp11, tmp12);
    auto tmp14 = cvt_fp32_to_uint8(tmp13);
    tmp14.store(out_ptr0 + static_cast<long>(i0), 16);
}
```
**Test Plan**
```
python -m pytest test_cpu_repro.py -k test_decomposed_dequant_relu_quant
python -m pytest test_cpu_repro.py -k test_tile2d_load_decomposed_dequant_add_relu_quant
python -m pytest test_cpu_repro.py -k test_tile2d_store_channel_shuffle_cl_quant_output
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104075
Approved by: https://github.com/jgong5, https://github.com/jansel
This PR decouples the logic necessary to compute bounds on variables
from the logic that uses this info to perform the strength analysis on
int64 variables. While doing so, it tries to minimize the number of
attributes of the class in favour of local variables.
This class is now accessible from any `LoopBody` object.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100549
Approved by: https://github.com/eellison
Fix https://github.com/pytorch/pytorch/issues/100830.
For the inplace node, a `copy_` is generated and `realized` as a scheduler buffer since it is a mutation. This scheduler buffer is a memory copy, but after being fused with the previous buffer, the fused kernel is no longer just a memory copy.
This PR solves the issue by removing `load_bf16_as_fp32` and `store_bf16_from_fp32`. Instead, we enable fp32/bf16 vectorized conversion in `to_dtype` and always store bf16.
```python
import torch
import torch.nn as nn
torch.manual_seed(420)
from torch._inductor import config

x = torch.randn(1, 18, dtype=torch.bfloat16)

class ExampleModel(nn.Module):
    def __init__(self):
        super(ExampleModel, self).__init__()
        self.relu = nn.ReLU(inplace=True)  # nn.ReLU(inplace=False)

    def forward(self, input1):
        out = self.relu(input1)
        # input1.copy_(out)
        return out

func = ExampleModel()

with torch.no_grad():
    func.train(False)
    res1 = func(x)  # without jit
    print(res1)

    jit_func = torch.compile(func)
    res2 = jit_func(x)
    print(res2)
```
Generated code without this PR (the `tmp3` store is wrong: `tmp3` is `float` while `out_ptr1` is `bf16`):
```
auto tmp0 = load_bf16_as_float(out_ptr1 + static_cast<long>(i0));
auto tmp1 = (tmp0);
auto tmp2 = at::vec::clamp_min(tmp1, decltype(tmp1)(0));
auto tmp3 = (tmp2);
store_float_as_bf16(out_ptr0 + static_cast<long>(i0), tmp3);
tmp3.store(out_ptr1 + static_cast<long>(i0), 16);
```
Generated code with this PR:
```
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(out_ptr1 + static_cast<long>(i0), 16);
auto tmp1 = cvt_bf16_to_fp32(tmp0);
auto tmp2 = at::vec::clamp_min(tmp1, decltype(tmp1)(0));
auto tmp3 = cvt_fp32_to_bf16(tmp2);
tmp3.store(out_ptr0 + static_cast<long>(i0), 16);
tmp3.store(out_ptr1 + static_cast<long>(i0), 16);
```
This PR also fixes the data type propagation for `masked_subblock`.
Previously, the masked_subblock's dtype was propagated from its inputs, which is wrong:
```
opcode name target args kwargs
----------- --------- --------- -------------------------- --------
call_module masked_subblock1 masked_subblock1 (and__2, -inf)
```
Now we propagate it from the subblock with the same name:
```
# graph for body.subblocks['masked_subblock1']
opcode name target args kwargs
----------- --------- --------- -------------------------- --------
placeholder ops ops () {}
call_module get_index get_index ('index2',) {}
call_method load load (ops, 'arg0_1', get_index) {}
call_method to_dtype to_dtype (ops, load, torch.float32) {}
output output output (to_dtype,) {}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101042
Approved by: https://github.com/jgong5, https://github.com/jansel
Currently, if we have an inplaced buffer that's completely internal to a fused kernel and thus doesn't need to be allocated, we still allocate it and pass an unused argument to the kernel, because our buffer-removal analysis treats it separately (assuming that either the original or the mutated value is still needed).
This PR extends buffer removal to inplaced buffers that can be removed.
The generated kernel for e.g. layer norm changes from
```
def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
```
where `in_out_ptr0` is unused in the kernel, to
```
def triton_(in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
```
and corresponding allocation/reuse lines in the wrapper are removed.
The `in_out_ptr1` is also mislabeled - it's not `in_out`, it's only written to, but this PR doesn't fix it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102289
Approved by: https://github.com/jansel
## Issue description
PR https://github.com/pytorch/pytorch/pull/100064 introduces a new RNG operation process. However, it causes every `randint` to load a separate random seed by default. TorchInductor generates a buffer to store all necessary random seeds and places the offsets as constant values in the subsequent compute buffers. In the `ir_pre_fusion` output generated by TorchInductor, some buffers differ only by one line: the load of the random seed with the corresponding offset. The codegen then generates Triton kernels following the same rule, so in `output_code.py` some Triton kernels differ only by that one line, meaning redundant kernels are being generated.
## Solution
This PR captures the seed offset and adds it to the existing `self.sizevars` structure. It generates variable names as placeholders, allowing the code wrapper to pass the offset as an argument to the kernels. I've also modified the divisible_by_16 check to exclude this argument.
This PR reduces the number of generated kernels from 50 to 17 for BertForMaskedLM forward.
According to tests on my own environment, the compilation time of attention_is_all_you_need_pytorch has been reduced from 94s to 66s. The speedup remains largely unchanged, at 1.37X.
The following is a comparison for a simple example.
Before:
```
triton_poi_fused_0 = async_compile.triton('triton_', '''
...
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
...
tmp0 = tl.load(in_ptr0 + 0)
tmp1 = x0
tmp2 = triton_helpers.randint64(tmp0, (tmp1).to(tl.uint32), 0, 10)
triton_poi_fused_1 = async_compile.triton('triton_', '''
...
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
...
tmp0 = tl.load(in_ptr0 + 1)
tmp1 = x0
tmp2 = triton_helpers.randint64(tmp0, (tmp1).to(tl.uint32), 0, 10)
...''')
def call(args):
triton_poi_fused_0.run(buf0, buf1, 1024, grid=grid(1024), stream=stream0)
triton_poi_fused_1.run(buf0, buf2, 1024, grid=grid(1024), stream=stream0)
```
After:
```
triton_poi_fused_0 = async_compile.triton('triton_', '''
...
def triton_(in_ptr0, out_ptr0, load_seed_offset, xnumel, XBLOCK : tl.constexpr):
...
tmp0 = tl.load(in_ptr0 + load_seed_offset)
tmp1 = x0
tmp2 = triton_helpers.randint64(tmp0, (tmp1).to(tl.uint32), 0, 10)
....
def call(args):
triton_poi_fused_0.run(buf0, buf1, 0, 1024, grid=grid(1024), stream=stream0)
triton_poi_fused_0.run(buf0, buf2, 1, 1024, grid=grid(1024), stream=stream0)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102104
Approved by: https://github.com/jansel, https://github.com/ngimel
This wraps `ops` into an `OpsWrapper` object which wraps any returned
IR values into an `OpsValue` instance. This allows magic methods to
be implemented and means lowerings can write mathematical expressions much more
fluently. So instead of
```python
ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1)
```
we can write
```python
(_Ap2 * x - _Ap3) * x * x + _1
```
And it will translate to the equivalent `ops` calls.
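A minimal, self-contained sketch of the mechanism (`_DemoOps` below is a stand-in for inductor's real ops handler, which the actual `OpsWrapper`/`OpsValue` delegate to):
```python
class _DemoOps:
    # stand-in ops handler; here it just builds nested call strings
    def add(self, a, b): return f"ops.add({a}, {b})"
    def sub(self, a, b): return f"ops.sub({a}, {b})"
    def mul(self, a, b): return f"ops.mul({a}, {b})"

ops = _DemoOps()

def _unwrap(v):
    return v.value if isinstance(v, OpsValue) else v

class OpsValue:
    def __init__(self, value):
        self.value = value  # the underlying IR value / CSE variable
    def __add__(self, other): return OpsValue(ops.add(self.value, _unwrap(other)))
    def __sub__(self, other): return OpsValue(ops.sub(self.value, _unwrap(other)))
    def __mul__(self, other): return OpsValue(ops.mul(self.value, _unwrap(other)))

_Ap2, _Ap3, x, _1 = (OpsValue(s) for s in ("_Ap2", "_Ap3", "x", "_1"))
print(((_Ap2 * x - _Ap3) * x * x + _1).value)
# ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1)
```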
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101076
Approved by: https://github.com/lezcano, https://github.com/ngimel
Fixes #100831, fixes #100878
Previously `gen_assert_indirect_indexing` was only called on the index
expressions passed to `ops.load` and `ops.store`, which means that if the
variable is optimized out during lowering, we never generate the
assert. This instead makes `ops.indirect_indexing` eagerly generate
the assert statement, whether or not it will be used.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100895
Approved by: https://github.com/lezcano, https://github.com/ngimel
**Summary**
Since the current quantization flow has not decomposed quant/dequant into prim ops, in this PR:
- We enable quant/dequant decomposition as a lowering inside Inductor (a small reference sketch of the decomposed math follows this list).
- For the `decomposed.quant/dequant.tensor` overloads, there are loads of scalar tensors for `zero_point` and `scale`, so we need to enable vectorized code generation for these op overloads.
- Minor change: add `is_load_uint8_as_float` and `is_store_float_as_uint8` with default value `False` to `OptimizationContext`.
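For reference, the math the decomposed quant/dequant ops compute is the standard affine scheme (this sketch uses plain tensor ops, not the exact decomposed op signatures):
```python
import torch

def dequantize(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # uint8 -> float: subtract zero_point, then scale
    return (q.to(torch.float32) - zero_point) * scale

def quantize(x: torch.Tensor, scale: float, zero_point: int,
             qmin: int = 0, qmax: int = 255) -> torch.Tensor:
    # float -> uint8: scale, shift by zero_point, clamp to the quantized range
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return q.to(torch.uint8)
```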
**TestPlan**
```
cd test/inductor && python -m pytest test_cpu_repro.py -k test_dequant_quant_lowering
```
Co-authored with @Xia-Weiwen.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99131
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
This PR also adds a way to CSE statements (not only assignments).
The tests follow the pattern from https://github.com/openai/triton/pull/1143
They take a fair amount of time to run (90s on my box). If we wanted to
improve this, we could avoid testing the `ndim == 3` case.
Changes like this one make me hope that we get to clean up the number of
lowerings we have at some point...
Generated code for `x[y]` with `x.shape == (3, 2, 4), y.ndim == 1`:
With `dynamic=False`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < 3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < 3")
tmp1 = tl.load(in_ptr1 + (x0 + (8*tmp0)), xmask)
```
With `dynamic=True`:
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tmp1 = tl.load(in_ptr1 + (x0 + (ks1*ks2*tmp0)), xmask)
```
Generated code for `x[y+1, y+1]` with `x.shape == (3, 2, 4), y.shape == (3, 3)`:
With `dynamic=False` (note how it folds the two upper bounds to `min(3, 2) == 2`):
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = 1
tmp2 = tmp0 + tmp1
tl.device_assert(((0 <= tmp2) & (tmp2 < 2)) | (~xmask), f"index out of bounds: 0 <= tmp2 < 2")
tmp3 = tl.load(in_ptr1 + (x0 + (12*tmp2)), xmask)
```
With `dynamic=True`:
```python
tl.device_assert(((0 <= tmp2) & (tmp2 < min(ks2, ks1))) | (~xmask), f"index out of bounds: 0 <= tmp2 < min(ks2, ks1)")
```
The same works when the CSE'd variable appears 3 or more times, but then it generates `min(ks0, min(ks1, ks2))`
Generated code for `x[y] = z` with `x.ndim = 3`, `y.ndim = 1` and dynamic shapes
```python
tmp0 = tl.load(in_ptr0 + (x1), xmask)
tmp1 = tl.load(in_ptr1 + (x2), xmask)
tl.device_assert(((0 <= tmp0) & (tmp0 < ks3)) | (~xmask), f"index out of bounds: 0 <= tmp0 < ks3")
tl.store(out_ptr0 + (x0 + (ks1*ks2*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
```
Fixes https://github.com/pytorch/pytorch/issues/93538
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98590
Approved by: https://github.com/ngimel
This makes only a cosmetic change to the generated code, but means
triton's broadcasting logic doesn't leak out into the CSE class.
Before:
```python
tmp5_load = tl.load(in_ptr1 + (0))
tmp5 = tl.broadcast_to(tmp5_load, [XBLOCK])
```
After:
```python
tmp5 = tl.load(in_ptr1 + (0))
tmp6 = tl.broadcast_to(tmp5, [XBLOCK])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98304
Approved by: https://github.com/ngimel
OK, so this PR used to be about reducing the number of constants we specialize on, but it turns out that unspecialization was ~essentially never used (because we still constant specialized way too aggressively) and I ended up having to fix a bunch of issues to actually get tests to pass. So this PR is now "make int unspecialization actually work". As part of this, I have to turn off unspecialization by default, as there are still latent bugs in inductor.
The general strategy is that an unspecialized int is represented as a SymInt. Representing it as a 0d tensor (which is what the code used to do) is untenable: (1) we often need unspecialized ints to participate in size computations, but we have no way of propagating sympy expressions through tensor compute, and (2) a lot of APIs work when passed SymInt, but not when passed a Tensor. However, I continue to represent Numpy scalars as Tensors, as they are rarely used for size computation and they have an explicit dtype, so they are more accurately modeled as 0d tensors.
* I folded in the changes from https://github.com/pytorch/pytorch/pull/95099 as I cannot represent unspecialized ints as SymInts without also turning on dynamic shapes. This also eliminates the necessity for test_unspec.py, as toggling specialization without dynamic shapes doesn't do anything. As dynamic shapes defaults to unspecializing, I just deleted this entirely; for the specialization case, I rely on regular static shape tests to catch it. (Hypothetically, we could also rerun all the tests with dynamic shapes, but WITH int/float specialization, but this seems... not that useful? I mean, I guess export wants it, but I'd kind of like our Source heuristic to improve enough that export doesn't have to toggle this either.)
* Only 0/1 integers get specialized by default now
* A hodgepodge of fixes. I'll comment on the PR about them.
Fixes https://github.com/pytorch/pytorch/issues/95469
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95621
Approved by: https://github.com/jansel, https://github.com/Chillee
This generates compilable code for maskrcnn graph 13, with ceilings hoisted to be computed on the host. But it now fails with
```
  File "/scratch/ngimel/work/pytorch/torch/_dynamo/symbolic_convert.py", line 379, in wrapper
    self.output.compile_subgraph(self, reason=reason)
  File "/scratch/ngimel/work/pytorch/torch/_dynamo/output_graph.py", line 562, in compile_subgraph
    pass1.foreach(stack_values)
  File "/scratch/ngimel/work/pytorch/torch/_dynamo/codegen.py", line 166, in foreach
    self(i)
  File "/scratch/ngimel/work/pytorch/torch/_dynamo/codegen.py", line 148, in __call__
    output.extend(value.reconstruct(self))
  File "/scratch/ngimel/work/pytorch/torch/_dynamo/variables/dicts.py", line 40, in reconstruct
    codegen.create_load_python_module(collections),
TypeError: create_load_python_module() missing 1 required positional argument: 'push_null'

from user code:
  File "/scratch/ngimel/work/env/lib/python3.9/site-packages/torchvision-0.15.0a0+928b05c-py3.9-linux-x86_64.egg/torchvision/models/detection/backbone_utils.py", line 58, in forward
    x = self.fpn(x)
```
looks like we never execute this `create_load_python_module()` path for other subgraphs.
Any advice on how to fix this, @voznesenskym @jansel?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95690
Approved by: https://github.com/jansel