**Summary**
Previously, we used `data_type_propagation` at the start of `codegen` to deduce the data type of each node and save this information in `node.meta[OptimizationContext.key]`. Then, we used this node metadata to update the cppcsevar data type in `update_on_args`. However, this method is not always correct. For example, in the codegen of `indirect_indexing` (see [here](096dc444ce/torch/_inductor/codegen/common.py (L1844))), we insert nodes on the fly and reuse the node of `indirect_indexing` to set the `cppcsevar` data type. In this PR, we plan to enhance the `cppcsevar` data type deduction:
- We will deduce the `cppcsevar` data type in `update_on_args` by reusing the code in `data_type_propagation`.
- To align the data type of scalar and vector variables, we previously always cast the scalar to the vector's data type. This caused a data type misalignment between `codegen` and `data_type_propagation`. We should use the same data type promotion logic to align the data types of scalar and vector variables.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130827
Approved by: https://github.com/jgong5, https://github.com/jansel
```
# Mode to emulate pytorch eager numerics for lower precision (fp16, bf16)
# Pytorch eager computes bf16/fp16 by upcasting inputs to fp32 and downcasting after
# For multiple, fused pointwise nodes, inductor will elide the intermediary upcasts and downcasts
# Typically this should be closer to fp64 ref numerics. However, it can be useful for debugging
# to emulate the eager numerics.
```
We add extra upcasts and downcasts for pointwise nodes that correspond to casts that existed in the original user program (excluding pointwise nodes that are emitted during decomposition). Since this is mostly for debugging, I added this information in the `meta` so that this mode does not have unintended side effects like changing pattern matching.
in theory there could also be some other casts with fused reduction -> reduction, although i havent seen this in practice as much. could be done as follow up. note: only works with cuda backend right now.
This mode was sufficient to eliminate compile differences from https://fb.workplace.com/groups/385893200869952/posts/464263173032954/?comment_id=465199259606012&reply_comment_id=465676792891592.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131595
Approved by: https://github.com/shunting314, https://github.com/bdhirsh, https://github.com/jansel
The conversion cache used for fixing https://github.com/pytorch/pytorch/issues/115260 depended on "store" which might be removed and ignored. This would lead to inconsistent code generated between vec and scalar kernels since we generate scalar kernel first followed by the vector kernel and the store buffer might be removed by the scalar and impacts the vector kernel codegen. This PR move the caching from "store" to the "to_dtype" calls which won't be impacted by the removed buffers.
`pytest -k test_consistent_remove_buffers test/inductor/test_cpu_repro.py`
before
```c++
extern "C" void kernel(const bfloat16* in_ptr0,
bfloat16* out_ptr1)
{
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(x0), 16);
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = tmp1 + tmp1;
auto tmp3 = at::vec::convert<bfloat16>(tmp2);
auto tmp4 = at::vec::convert<float>(tmp3);
auto tmp5 = tmp1 + tmp4;
auto tmp6 = at::vec::convert<bfloat16>(tmp5);
tmp6.store(out_ptr1 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(64L); x0<static_cast<long>(65L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp1 = c10::convert<float>(tmp0);
auto tmp2 = decltype(tmp1)(tmp1 + tmp1);
auto tmp3 = c10::convert<bfloat16>(tmp2);
auto tmp4 = decltype(tmp1)(tmp1 + tmp2);
auto tmp5 = c10::convert<bfloat16>(tmp4);
out_ptr1[static_cast<long>(x0)] = tmp5;
}
}
}
```
after
```c++
extern "C" void kernel(const bfloat16* in_ptr0,
bfloat16* out_ptr1)
{
{
for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(in_ptr0 + static_cast<long>(x0), 16);
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = tmp1 + tmp1;
auto tmp3 = at::vec::convert<bfloat16>(tmp2);
auto tmp4 = tmp1 + tmp2;
auto tmp5 = at::vec::convert<bfloat16>(tmp4);
tmp5.store(out_ptr1 + static_cast<long>(x0), 16);
}
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(64L); x0<static_cast<long>(65L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
auto tmp1 = c10::convert<float>(tmp0);
auto tmp2 = decltype(tmp1)(tmp1 + tmp1);
auto tmp3 = c10::convert<bfloat16>(tmp2);
auto tmp4 = decltype(tmp1)(tmp1 + tmp2);
auto tmp5 = c10::convert<bfloat16>(tmp4);
out_ptr1[static_cast<long>(x0)] = tmp5;
}
}
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130677
Approved by: https://github.com/leslie-fang-intel
**Summary**
Support more than 1 Local Buffer in an outer loop fused node and also the case when multi global buffers sharing usage of same local buffer.
**TestPlan**
```
python -u -m pytest -s -v inductor/test_cpu_repro.py -k test_two_local_buffers_in_outer_loop_fusion
python -u -m pytest -s -v inductor/test_cpu_repro.py -k test_share_local_buffers_in_outer_loop_fusion
```
**Next Step**
- [✓] Support more than one Local Buffer/Global Buffer
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129121
Approved by: https://github.com/jgong5, https://github.com/peterbell10
ghstack dependencies: #126967
**Summary**
When check the vectorization status among 3 test suit, we found some operators disabled vectorization with message `Disabled vectorization: op: remainder`. In this PR, we add vectorization support of this op.
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_vec_remainder
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_int_div_vec
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129849
Approved by: https://github.com/jgong5, https://github.com/lezcano
ghstack dependencies: #130405
**Summary**
Support more than 1 Local Buffer in an outer loop fused node and also the case when multi global buffers sharing usage of same local buffer.
**TestPlan**
```
python -u -m pytest -s -v inductor/test_cpu_repro.py -k test_two_local_buffers_in_outer_loop_fusion
python -u -m pytest -s -v inductor/test_cpu_repro.py -k test_share_local_buffers_in_outer_loop_fusion
```
**Next Step**
- [✓] Support more than one Local Buffer/Global Buffer
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129121
Approved by: https://github.com/jgong5, https://github.com/peterbell10
ghstack dependencies: #126967
Currently a buffer represents both a tensor with physical storage and a
computation that produces the tensor as a result.
This PR attempts to split these into two different concepts in the scheduler.
This should allow us to have multiple outputs from a single operation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128893
Approved by: https://github.com/lezcano
**Summary**
When check the vectorization status among 3 test suit, we found some operators disabled vectorization with message `Disabled vectorization: op: bitwise_and`. In this PR, we add vectorization support of 6 bitwise functions.
In this PR, we also remove `bitwise_xor` from `ops_to_bool` list which sets output data type as bool in data type propagation. It seems wrong since according to this doc
https://pytorch.org/docs/stable/generated/torch.bitwise_xor.html, it should return the same integral data type with input and the testcase `test_bitwise3` failed due to this issue.
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_vec_bitwise
python -u -m pytest -s -v test/inductor/test_torchinductor.py -k test_bitwise3
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129733
Approved by: https://github.com/jgong5, https://github.com/Skylion007
As part of #125683, this PR adds epilogue fusion support for bf16/fp16 gemms. The key changes are as follows:
1. bf16 linear w/ epilogue fusion of some ops was originally supported via ATen oneDNN linear pointwise ops. In order to match the ATen op semantics, in-template epilogue support is added to the cpp gemm template so that we would have: "gemm + in-template epilogues -> template buffer". If the template is chosen for codegen, the in-template epilogues will be concatenated with the out-of-template epilogues that are appended during the scheduling.
2. Support bf16/fp16 legalization for `codegen_loop_bodies` which is used to generate the epilogue loops.
3. We used to leverage the in-place buffer mechanism to handle the in-place buffers in the epilogue codegen, in particular, for the reuses for output buffers of GEMM, template and epilogues. This is not correct since the output buffer is an "output" not an "in-place" buffer of the template kernel itself. Now, we use a dedicated "aliases" dict to manage such buffer reuses and the intermediate aliasing buffers are removed after codegen.
4. Add `localize_buffer` method to `LocalBufferScope` to allow the replacement of a global buffer with a local one in the given inductor IR nodes. This helps the fused loops to work on smaller-sized local buffers for better data locality.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126545
Approved by: https://github.com/jansel
At a high level, the idea behind this PR is:
* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.
The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:
* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)
In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.
We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:
* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`
These changes have consequences. First, we need to make some administrative changes:
* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
* In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
* TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.
In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:
* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type
The new asserts uncovered necessary bug fixes:
* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1
Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**
**Reland notes.** This requires this internal fbcode diff https://www.internalfb.com/phabricator/paste/view/P1403322587 but I cannot prepare the diff codev due to https://fb.workplace.com/groups/osssupport/posts/26343544518600814/
It also requires this Executorch PR https://github.com/pytorch/executorch/pull/3911 but the ET PR can be landed prior to this landing.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
At a high level, the idea behind this PR is:
* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.
The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:
* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)
In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.
We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:
* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`
These changes have consequences. First, we need to make some administrative changes:
* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
* In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
* TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.
In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:
* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type
The new asserts uncovered necessary bug fixes:
* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1
Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
At a high level, the idea behind this PR is:
* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.
The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:
* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)
In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.
We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:
* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`
These changes have consequences. First, we need to make some administrative changes:
* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
* In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
* TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.
In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:
* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type
The new asserts uncovered necessary bug fixes:
* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1
Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
As part of #125683, this PR adds the initial bf16/fp16 gemm template support with micro-gemm implemented with fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet which will be added in the next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126068
Approved by: https://github.com/jansel
ghstack dependencies: #124021, #126019
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126019
Approved by: https://github.com/jansel
ghstack dependencies: #124021
This pass was broken in a number of ways, as we were not generating
asserts whenever we took it, even though we need to. While doing so,
we found that the analysis we were using for choosing
whether to generate asserts or not for dynamic shapes was completely
broken.
Eliminating indirect indexing in this way allows for a number of optimisations.
In particular, we can now fuse against these kernels (indirect indexing disallows fusions).
The new strategy is as follows:
- We always propagate sympy expressions if we can.
- If an expression was an indirect_indexing, we call `check_bounds`
- We also call `check_bounds` within `CSEProxy.indirect_indexing`
- The checks are issued in the buffer where they would go if the were used in a load
- This makes them always be codegen'd before the load and stores
- In the case of stores, they will be generated potentially much earlier than the stores themselves, which is fine.
We add quite a few asserts to preexisting tests to strengthen them. In particular, we make sure
that issuing an assert plays well with all kinds of C++ vectorisation.
For now, we rely on the logic within `_maybe_evaluate_static` to prove
these bounds. This logic is rather limited though. In the future, we might want
to rely on Z3 here to be able to prove bounds in a more general way.
Supersedes https://github.com/pytorch/pytorch/pull/113068
Fixes https://github.com/pytorch/pytorch/issues/121251
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114471
Approved by: https://github.com/peterbell10
This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC https://github.com/pytorch/pytorch/issues/125683 for more background info.
1. Cpp template infrastructure
Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`. The MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates.
2. Initial FP32 gemm template
This involves a GEMM template implementation `CppPackedGemmTemplate` that supports GEMM with constant weight (`B`) requiring `N` to be a multiple of register blocking while allows the static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix would be prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`). Then it invokes `CppMicroGemm` which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls implemented with ATen vec abstraction.
3. Correctness and performance
The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static shape and dynamic shapes. Since it is an initial implementation, we are still working on further performance improves with follow-up PRs including the optimizations in kernels as well as fusions. The perf gains are only observed from a selective number of models compared to the ATen kernels which are implemented with MKL. The perf gains are more obvious with dynamic shapes since MKL only supports packed gemm for static shapes. Below are details.
Static shapes
| Benchmark | torchbench | huggingface | timm_models |
|------------|-------------|--------------|--------------|
| Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x |
| Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x |
| Single-threaded (baseline) | 1.56x | 1.19x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x |
Key models being sped up:
drq: 1.14x
soft_act: 1.12
cait_m36_384: 1.18x
Dynamic shapes
| Benchmark | torchbench | huggingface | timm_models |
| --- | --- | --- | --- |
| Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x |
| Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x |
| Single-threaded (baseline) | 1.55x | 1.20x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x |
Key models being sped up:
BERT_pytorch: 1.22x
pyhpc_turbulent: 1.13x
soft_actor_critic: 1.77x
BlenderbotForCausalLM: 1.09x
cait_m36_384: 1.17x
Differential Revision: [D57585365](https://our.internmc.facebook.com/intern/diff/D57585365)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021
Approved by: https://github.com/jansel