Commit Graph

225 Commits

Author SHA1 Message Date
Jason Ansel
4f9399bd0d [halide-backend] Initial implementation of HalideKernel and HalideScheduling (#126417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126417
Approved by: https://github.com/shunting314, https://github.com/eellison
2024-06-22 17:39:52 +00:00
Jason Ansel
feb3f3ad77 [inductor] Refactors for Halide backend (#129024)
Pulling these inductor-related refactors out of the larger Halide
backend PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129024
Approved by: https://github.com/shunting314, https://github.com/eellison
2024-06-21 16:53:35 +00:00
eellison
c187593418 Prevent expansion of cat indexing to avoid int64 intermediate (#127815)
Fix for https://github.com/pytorch/pytorch/issues/127652

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127815
Approved by: https://github.com/shunting314, https://github.com/peterbell10
2024-06-14 15:42:08 +00:00
Isuru Fernando
e397ad6883 Improve codegen for ops.masked in triton (#128054)
Fixes https://github.com/pytorch/pytorch/issues/127930
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128054
Approved by: https://github.com/peterbell10, https://github.com/lezcano
2024-06-14 11:52:56 +00:00
Jason Ansel
c897651392 [inductor] Add BackendFeature gating (#128266)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128266
Approved by: https://github.com/shunting314
2024-06-13 07:31:51 +00:00
PyTorch MergeBot
f2dcbe89d6 Revert "Prevent expansion of cat indexing to avoid int64 intermediate (#127815)"
This reverts commit 793df7b7cb.

Reverted https://github.com/pytorch/pytorch/pull/127815 on behalf of https://github.com/clee2000 due to the newly added test is failing internally D58444153.  Test exists in opensource and passed in OSS CI, maybe env difference? ([comment](https://github.com/pytorch/pytorch/pull/127815#issuecomment-2163421968))
2024-06-12 16:09:22 +00:00
eellison
793df7b7cb Prevent expansion of cat indexing to avoid int64 intermediate (#127815)
Fix for https://github.com/pytorch/pytorch/issues/127652

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127815
Approved by: https://github.com/shunting314, https://github.com/peterbell10
2024-06-11 02:41:07 +00:00
Edward Z. Yang
3964a3ec73 Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)

In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations.  Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`

These changes have consequences. First, we need to make some administrative changes:

* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
  * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1

Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**

**Reland notes.** This requires this internal fbcode diff https://www.internalfb.com/phabricator/paste/view/P1403322587 but I cannot prepare the diff codev due to https://fb.workplace.com/groups/osssupport/posts/26343544518600814/

It also requires this Executorch PR https://github.com/pytorch/executorch/pull/3911 but the ET PR can be landed prior to this landing.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
2024-06-09 06:20:25 +00:00
Aaron Orenstein
ea614fb2b1 Flip default value for mypy disallow_untyped_defs [2/11] (#127839)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127839
Approved by: https://github.com/oulgen
2024-06-08 18:23:08 +00:00
PyTorch MergeBot
ac51f782fe Revert "Complete revamp of float/promotion sympy handling (#126905)"
This reverts commit 2f7cfecd86.

Reverted https://github.com/pytorch/pytorch/pull/126905 on behalf of https://github.com/atalman due to Sorry need to revert - failing internally ([comment](https://github.com/pytorch/pytorch/pull/126905#issuecomment-2155118778))
2024-06-07 16:01:46 +00:00
Edward Z. Yang
2f7cfecd86 Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)

In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations.  Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`

These changes have consequences. First, we need to make some administrative changes:

* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
  * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1

Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
2024-06-06 02:29:45 +00:00
PyTorch MergeBot
d5cb5d623a Revert "Complete revamp of float/promotion sympy handling (#126905)"
This reverts commit fb696ef3aa.

Reverted https://github.com/pytorch/pytorch/pull/126905 on behalf of https://github.com/ezyang due to internal user reported ceiling equality simplification problem, I have a plan ([comment](https://github.com/pytorch/pytorch/pull/126905#issuecomment-2148805840))
2024-06-05 03:57:58 +00:00
Edward Z. Yang
fb696ef3aa Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)

In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations.  Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`

These changes have consequences. First, we need to make some administrative changes:

* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
  * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1

Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
2024-06-04 11:47:32 +00:00
Edward Z. Yang
a4064da8ca Always simplify sympy expressions before printing. (#127543)
This is important because if a replacement has happened during inductor lowering, we may have stale symbols in sympy expressions that we need to replace away.  Do this at the very end.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127543
Approved by: https://github.com/lezcano
2024-06-03 20:36:14 +00:00
lezcano
0fa2c5b049 Fix mask propagation in the presence of where (#125574)
Before, when calling ops.where, masks were not properly propagated. We
now restrict the optimisation to `ops.masked`, which I think it was what
the original code intended to do.

I'm not 100% sure that even in the masked case this code is not
introducing some bugs, but this is a strict improvement over the
previous state.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125574
Approved by: https://github.com/peterbell10
ghstack dependencies: #114471, #126783
2024-05-29 23:17:41 +00:00
Jiong Gong
92bc444ee3 [inductor][cpp] epilogue support for gemm template (#126019)
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126019
Approved by: https://github.com/jansel
ghstack dependencies: #124021
2024-05-29 09:12:03 +00:00
lezcano
8a21532e53 Fix constant propagation pass (#114471)
This pass was broken in a number of ways, as we were not generating
asserts whenever we took it, even though we need to. While doing so,
we found that the analysis we were using for choosing
whether to generate asserts or not for dynamic shapes was completely
broken.

Eliminating indirect indexing in this way allows for a number of optimisations.
In particular, we can now fuse against these kernels (indirect indexing disallows fusions).

The new strategy is as follows:

- We always propagate sympy expressions if we can.
- If an expression was an indirect_indexing, we call `check_bounds`
- We also call `check_bounds` within `CSEProxy.indirect_indexing`
- The checks are issued in the buffer where they would go if the were used in a load
   - This makes them always be codegen'd before the load and stores
   - In the case of stores, they will be generated potentially much earlier than the stores themselves, which is fine.

We add quite a few asserts to preexisting tests to strengthen them. In particular, we make sure
that issuing an assert plays well with all kinds of C++ vectorisation.

For now, we rely on the logic within `_maybe_evaluate_static` to prove
these bounds. This logic is rather limited though. In the future, we might want
to rely on Z3 here to be able to prove bounds in a more general way.

Supersedes https://github.com/pytorch/pytorch/pull/113068
Fixes https://github.com/pytorch/pytorch/issues/121251

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114471
Approved by: https://github.com/peterbell10
2024-05-29 09:10:25 +00:00
PyTorch MergeBot
343a41fba8 Revert "[inductor][cpp] epilogue support for gemm template (#126019)"
This reverts commit 56c412d906.

Reverted https://github.com/pytorch/pytorch/pull/126019 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/124021#issuecomment-2133002071))
2024-05-27 09:01:45 +00:00
Jason Ansel
92433217cb [inductor] Misc refactors (#126945)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126945
Approved by: https://github.com/shunting314
ghstack dependencies: #126944
2024-05-24 22:46:20 +00:00
Jiong Gong
56c412d906 [inductor][cpp] epilogue support for gemm template (#126019)
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126019
Approved by: https://github.com/jansel
ghstack dependencies: #124021
2024-05-24 12:14:12 +00:00
PyTorch MergeBot
45784cd229 Revert "[inductor][cpp] epilogue support for gemm template (#126019)"
This reverts commit 08f57b4bff.

Reverted https://github.com/pytorch/pytorch/pull/126019 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/124021#issuecomment-2126568331))
2024-05-23 08:50:18 +00:00
Jiong Gong
08f57b4bff [inductor][cpp] epilogue support for gemm template (#126019)
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126019
Approved by: https://github.com/jansel
ghstack dependencies: #124021
2024-05-23 07:39:29 +00:00
PyTorch MergeBot
657d39e44c Revert "[inductor][cpp] epilogue support for gemm template (#126019)"
This reverts commit 57108d9a49.

Reverted https://github.com/pytorch/pytorch/pull/126019 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it has a land race and failing in trunk 2ac33a9f66 ([comment](https://github.com/pytorch/pytorch/pull/124021#issuecomment-2126016522))
2024-05-23 01:13:29 +00:00
Jiong Gong
57108d9a49 [inductor][cpp] epilogue support for gemm template (#126019)
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126019
Approved by: https://github.com/jansel
ghstack dependencies: #124021
2024-05-23 00:07:52 +00:00
Shunting Zhang
14c5c753de [inductor] use smaller RBLOCK for expensive reduction kernels (#126477)
Triton sometimes uses less registers for more expensive kernel which results in worse perf ( https://github.com/pytorch/pytorch/issues/126463 ). This may make inductor end up with a sub-optimal config. Use a smaller max RBLOCK if the reduction potentially need many registers.

Will run perf test..

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126477
Approved by: https://github.com/jansel
2024-05-22 22:47:10 +00:00
Bin Bao
b40fb2de59 [AOTI] Fix a codegen issue when .item() is used for kernel arg (#126575)
Summary: fixes https://github.com/pytorch/pytorch/issues/126574 . Pass kernel argument type information into generate_args_decl, so it can generate the argument declaration instead of relying on string matching.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126575
Approved by: https://github.com/chenyang78
ghstack dependencies: #126369
2024-05-21 18:20:20 +00:00
Edward Z. Yang
55033ab43a Update ops handler documentation some more (#126480)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126480
Approved by: https://github.com/peterbell10
ghstack dependencies: #126292, #126299
2024-05-17 13:31:44 +00:00
PyTorch MergeBot
4a5ef0b793 Revert "[inductor][cpp] epilogue support for gemm template (#126019)"
This reverts commit 7844c202b2.

Reverted https://github.com/pytorch/pytorch/pull/126019 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but the dependency PR https://github.com/pytorch/pytorch/pull/124021 is going to be revert ([comment](https://github.com/pytorch/pytorch/pull/126019#issuecomment-2116408137))
2024-05-17 00:15:00 +00:00
Jiong Gong
7844c202b2 [inductor][cpp] epilogue support for gemm template (#126019)
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126019
Approved by: https://github.com/jansel
2024-05-16 01:42:29 +00:00
Edward Z. Yang
2ba102f689 Implement native support for float inputs in Dynamo and ShapeEnv (#125325)
The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside, we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is controlled by `specialize_float` config variable when set to False.

The generated graph looks like this for the test `test_unspec_float_output`:

```
 def forward(self, L_x_: "f32[3]", L_y_: "f32[]"):
     l_x_ = L_x_
     l_y_ = L_y_

     # File: /data/users/ezyang/a/pytorch/test/dynamo/test_unspec.py:511 in f, code: return x + 1, y * 2
     add: "f32[3]" = l_x_ + 1;  l_x_ = None
     item: "Sym(zf0)" = l_y_.item();  l_y_ = None
     mul: "Sym(2*zf0)" = item * 2;  item = None
     scalar_tensor: "f32[]" = torch.scalar_tensor(mul);  mul = None
     return (add, scalar_tensor)
```

The ingredients:

* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up into the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, which we will do the rest of the tracing with.  Importantly, this SymNodeVariable is backed with the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable). This has to be done manually, because if you literally call item() on the tensor, you will end up with an unbacked float. There is a bit of copy paste from wrap_symint and wrap_unspecialized_primitive which we can try to factor out, but this really is its own thing and you should review every line of code in the function.
* **torch/fx/experimental/symbolic_shapes.py** We now can generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols, and produce guards for them. Fairly straightforward generalization.
* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs to the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable on the return stack for a float, we on the fly convert it (via `as_tensor`) to a TensorVariable, which is the true output. We then special case the output bytecode to call item() on it again. The tensor conversion is memoized on SymNodeVariable since we typically run the code generation process twice.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125325
Approved by: https://github.com/lezcano, https://github.com/jansel
2024-05-14 04:10:01 +00:00
lezcano
320af5eaa6 Compute bounds for the variables created during codegen (#123100)
Before we would just bail out on these bounds for all variables that did
not come from the FX graph. Now we propagate the bounds whenever we have
a rule for that op.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123100
Approved by: https://github.com/jgong5, https://github.com/peterbell10
2024-05-08 08:14:06 +00:00
PyTorch MergeBot
2a42c40791 Revert "Compute bounds for the variables created during codegen (#123100)"
This reverts commit bb668c6468.

Reverted https://github.com/pytorch/pytorch/pull/123100 on behalf of https://github.com/huydhn due to Sorry for reverting you change but it is failing inductor tests bb668c6468 ([comment](https://github.com/pytorch/pytorch/pull/123100#issuecomment-2096837821))
2024-05-06 20:23:39 +00:00
lezcano
bb668c6468 Compute bounds for the variables created during codegen (#123100)
Before we would just bail out on these bounds for all variables that did
not come from the FX graph. Now we propagate the bounds whenever we have
a rule for that op.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123100
Approved by: https://github.com/jgong5, https://github.com/peterbell10
2024-05-06 18:12:15 +00:00
Jiong Gong
68a1f787c8 [inductor][cpp] move some common cpp utils to cpp_utils.py (#125152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125152
Approved by: https://github.com/desertfire, https://github.com/jansel
2024-05-06 04:30:30 +00:00
Edward Z. Yang
6f70d22277 Extend torch.utils._sympy.symbol for more Inductor symbols (#125419)
I'm still missing a few, cdzq at least

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125419
Approved by: https://github.com/lezcano
ghstack dependencies: #125395
2024-05-04 09:05:00 +00:00
Edward Z. Yang
5503c29357 Introduce torch.utils._sympy.symbol (#125395)
This provides utilities for creating and querying properties on
sympy.Symbol.  I want to use this refactor to get a better handle on how
the 's' prefix is being used in Inductor.  To start, I only do
symbolic_shapes code because that's what I'm familiar with.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125395
Approved by: https://github.com/Skylion007
2024-05-03 21:24:23 +00:00
Edward Z. Yang
dae574c713 Don't make replacements for i variables (#125398)
This was introduced in https://github.com/pytorch/pytorch/pull/110262
but actually it looks like they were trying to hit unbacked SymInt.
Now that unbacked SymInt is renamed to u, this code is no longer
necessary

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125398
Approved by: https://github.com/lezcano, https://github.com/Skylion007
2024-05-02 20:38:09 +00:00
haozhe.zhu
c5b1a4c269 [inductor] share more cse cache during swap buffer (#124921)
`swap_buffer` will make the `cse_cache` cannot be shared inside/outside of the lambda function scope.
For example,

```
auto tmp8 = -std::numeric_limits<float>::infinity();
auto tmp9 = [&]
{
    auto tmp12 = -std::numeric_limits<float>::infinity();
    return tmp12;
}
```
`tmp12` should not be created since it is same with `tmp8`.

We make the `cse_cache` as a read only cache inside the scope (because it is unsafe to expose cache inside the scope,the outside scope cannot use it.)

**Test Plan**
```
python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu
```
the `static_cast<int>(256)` will only occur once after this PR since the inside scope can share the cse buffer outside the scope.

Before this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = c10::convert<int>(x3);
                                    auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1);
                                    auto tmp21 = static_cast<int>(256);
                                    auto tmp22 = at::vec::Vectorized<int>(tmp21);
                                    auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22);
                                    auto tmp25 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp26 = tmp23 & tmp25;
                                    auto tmp24 = [&]
                                    {
                                        auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp27;
                                    }
                                    ;
                                    auto tmp28 =
                                    [&]
                                    {
                                        if (tmp26.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp29 = static_cast<float>(0.0);
                                    auto tmp30 = at::vec::Vectorized<float>(tmp29);
                                    auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>());
                                    return tmp31;
                                }
                                ;
                                auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp33 = static_cast<float>(0.0);
                                auto tmp34 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp35 = at::vec::Vectorized<float>(tmp33);
                                auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>());
                                auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>());
                                return tmp37;
                            }
                            ;
                            auto tmp38 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp40 = static_cast<int>(3);
                            auto tmp41 = tmp39 < tmp40;
                            auto tmp42 = [&]
                            {
                                auto tmp43 = c10::convert<int>(x3);
                                auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1);
                                auto tmp45 = static_cast<int>(256);
                                auto tmp46 = at::vec::Vectorized<int>(tmp45);
                                auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46);
                                auto tmp49 = at::vec::VecMask<float,1>::from(tmp41);
                                auto tmp50 = tmp47 & tmp49;
                                auto tmp48 = [&]
                                {
                                    auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp51;
                                }
                                ;
                                auto tmp52 =
                                [&]
                                {
                                    if (tmp50.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp53 = static_cast<float>(0.0);
                                auto tmp54 = at::vec::Vectorized<float>(tmp53);
                                auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>());
                                return tmp55;
                            }
                            ;
                            auto tmp56 = tmp41 ? tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp57 = static_cast<float>(0.0);
                            auto tmp58 = at::vec::VecMask<float,1>::from(tmp41);
                            auto tmp59 = at::vec::Vectorized<float>(tmp57);
                            auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>());
                            auto tmp61 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>());
                            tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = c10::convert<int64_t>(x3);
                                    auto tmp15 = static_cast<int64_t>(256);
                                    auto tmp16 = tmp14 >= tmp15;
                                    auto tmp17 = [&]
                                    {
                                        auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp18;
                                    }
                                    ;
                                    auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
                                    auto tmp20 = static_cast<float>(0.0);
                                    auto tmp21 = tmp16 ? tmp19 : tmp20;
                                    return tmp21;
                                }
                                ;
                                auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp23 = static_cast<float>(0.0);
                                auto tmp24 = tmp12 ? tmp22 : tmp23;
                                auto tmp25 = tmp6 ? tmp9 : tmp24;
                                return tmp25;
                            }
                            ;
                            auto tmp26 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp28 = static_cast<int64_t>(3);
                            auto tmp29 = tmp27 < tmp28;
                            auto tmp30 = [&]
                            {
                                auto tmp31 = c10::convert<int64_t>(x3);
                                auto tmp32 = static_cast<int64_t>(256);
                                auto tmp33 = tmp31 >= tmp32;
                                auto tmp34 = [&]
                                {
                                    auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp35;
                                }
                                ;
                                auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0);
                                auto tmp37 = static_cast<float>(0.0);
                                auto tmp38 = tmp33 ? tmp36 : tmp37;
                                return tmp38;
                            }
                            ;
                            auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0);
                            auto tmp40 = static_cast<float>(0.0);
                            auto tmp41 = tmp29 ? tmp39 : tmp40;
                            auto tmp42 = tmp2 ? tmp26 : tmp41;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42;
                        }
                    }
                }
            }
        }
    }
}
''')
```
After this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = at::vec::Vectorized<int>(tmp1);
                                    auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19);
                                    auto tmp22 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp23 = tmp20 & tmp22;
                                    auto tmp21 = [&]
                                    {
                                        auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp24;
                                    }
                                    ;
                                    auto tmp25 =
                                    [&]
                                    {
                                        if (tmp23.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp26 = static_cast<float>(0.0);
                                    auto tmp27 = at::vec::Vectorized<float>(tmp26);
                                    auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>());
                                    return tmp28;
                                }
                                ;
                                auto tmp29 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp30 = static_cast<float>(0.0);
                                auto tmp31 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp32 = at::vec::Vectorized<float>(tmp30);
                                auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>());
                                auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>());
                                return tmp34;
                            }
                            ;
                            auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp37 = static_cast<int>(3);
                            auto tmp38 = tmp36 < tmp37;
                            auto tmp39 = [&]
                            {
                                auto tmp40 = c10::convert<int>(x3);
                                auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1);
                                auto tmp42 = at::vec::Vectorized<int>(tmp1);
                                auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42);
                                auto tmp45 = at::vec::VecMask<float,1>::from(tmp38);
                                auto tmp46 = tmp43 & tmp45;
                                auto tmp44 = [&]
                                {
                                    auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp47;
                                }
                                ;
                                auto tmp48 =
                                [&]
                                {
                                    if (tmp46.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp49 = static_cast<float>(0.0);
                                auto tmp50 = at::vec::Vectorized<float>(tmp49);
                                auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>());
                                return tmp51;
                            }
                            ;
                            auto tmp52 = tmp38 ? tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp53 = static_cast<float>(0.0);
                            auto tmp54 = at::vec::VecMask<float,1>::from(tmp38);
                            auto tmp55 = at::vec::Vectorized<float>(tmp53);
                            auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>());
                            auto tmp57 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>());
                            tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = tmp4 >= tmp1;
                                    auto tmp15 = [&]
                                    {
                                        auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp16;
                                    }
                                    ;
                                    auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0);
                                    auto tmp18 = static_cast<float>(0.0);
                                    auto tmp19 = tmp14 ? tmp17 : tmp18;
                                    return tmp19;
                                }
                                ;
                                auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp21 = static_cast<float>(0.0);
                                auto tmp22 = tmp12 ? tmp20 : tmp21;
                                auto tmp23 = tmp6 ? tmp9 : tmp22;
                                return tmp23;
                            }
                            ;
                            auto tmp24 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp26 = static_cast<int64_t>(3);
                            auto tmp27 = tmp25 < tmp26;
                            auto tmp28 = [&]
                            {
                                auto tmp29 = c10::convert<int64_t>(x3);
                                auto tmp30 = tmp29 >= tmp1;
                                auto tmp31 = [&]
                                {
                                    auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp32;
                                }
                                ;
                                auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0);
                                auto tmp34 = static_cast<float>(0.0);
                                auto tmp35 = tmp30 ? tmp33 : tmp34;
                                return tmp35;
                            }
                            ;
                            auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0);
                            auto tmp37 = static_cast<float>(0.0);
                            auto tmp38 = tmp27 ? tmp36 : tmp37;
                            auto tmp39 = tmp2 ? tmp24 : tmp38;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39;
                        }
                    }
                }
            }
        }
    }
}
''')
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124921
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #124597
2024-04-28 04:33:25 +00:00
Sergii Dymchenko
f0f7452e31 Do not propogate (#124769)
Fix the propogate typos.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124769
Approved by: https://github.com/Skylion007
2024-04-24 02:18:18 +00:00
lezcano
9a5b4d2403 Do not forward parent's value range to CSE variable for variables created within codegen. (#123099)
Consider we are generating code for `ops.gt`, and within it we call
`ops.to_dtype`. Before, we would forward the bounds from `gt` to the
to the result of `to_dtype`, which is wrong.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123099
Approved by: https://github.com/jgong5, https://github.com/peterbell10
2024-04-23 06:26:39 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Peter Bell
bd225189f1 [inductor] Change OverridesData to take callables instead of strings (#123397)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123397
Approved by: https://github.com/lezcano
2024-04-11 22:22:54 +00:00
Edward Z. Yang
efa36ef092 Natively support int truncation, don't guard on positive/negative (#122827)
This doesn't entirely fix the original problem that prompted this, but
it seems to just be getting stuck in export constraint formatting now
which seems like progress to me.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122827
Approved by: https://github.com/avikchaudhuri
2024-04-11 15:22:32 +00:00
Peter Bell
9189d04cb1 [inductor] Add explicit ops.fma and use it in softmax_backward (#122518)
This allows us to generate an fma even when fp-fusion is disabled
in the compiler.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122518
Approved by: https://github.com/lezcano, https://github.com/Chillee
2024-04-06 02:15:16 +00:00
drisspg
f4e2a226aa ScoreMod API (#121845)
# Summary

This PR adds a new higher-order_op: `templated_attention`.  This op is designed to extend the functionality of torch.nn.fucntional.scaled_dot_product_attention.  PyTorch has efficient pre-written fused-attention kernels. However, users want to modify how scores are computed (a substep inside attention) -- this traditionally requires the user to write their own attention kernel. One such modification to attention scores that is not currently supported by the top level SDPA op is:[ Attention with Linear Biases (ALiBi](https://arxiv.org/abs/2108.12409)).

This higher-order op will instead accept a callable( 'score_mod') function that is through torch.compile will be used to create an efficient attention kernel instantiation.

### Details

This HOP utilizes the existing fx and HOP infra to capture and convert the User `score-mod` function and convert to an FX graph module. Inductor then consumes this HOP that has a `ir.Subgraph` input. It will inline this lowered subgraph into a triton kernel which performs fused attention with the modification to the scores matrix inlined.

### API

The API for a score_mod function should be as follows:

```Python
def score_mod(score: torch.Tensor, batch: torch.Tensor, head: torch.Tensor, token_1: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```

This function receives five parameters:

- `score`: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors.
- `batch`, `head`, `seq_len_q`, `seq_len_kv`: Scalar tensors indicating the batch index, head index, query index, and key/value index, respectively, with torch.int data type and located on the same device as the score tensor.

Consider inputs query, key, value of shapes (2, 4, 16, 8), leading to an intermediate attention score matrix of shape (2, 4, 16, 16)

The score_mod function will be vectorized over each element of this matrix. For instance, modifying the score at the position corresponding to the 0th batch, 2nd head, between the 8th query and the 9th key element, would be invoked as:

```Python
score_mod(score[0,2,8,9], torch.tensor(0), torch.tensor(2), torch.tensor(8), torch.tensor(9))
```

### Examples
```Python
import torch
from torch.nn.attention.templated_attention import templated_attention

torch.manual_seed(0)

# Lets create some input tensors
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim)
query = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
key = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
value = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)

# Lets create a fun new score_modification! I will call this
# Checkerboard. It will reduce the score for neighboring tokens (1 step apart)
# in the sequence. And increase the score for tokens 2 steps apart. For everything
# else, the score will remain the same.

def checkerboard(score, batch, head, token_q, token_kv):
    score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score)
    score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score)
    return score

# Lets call templated_attention with this new score modification
output = templated_attention(query, key, value, score_mod=checkerboard)

compiled_templated_attention = torch.compile(templated_attention)
out_compiled = compiled_templated_attention(query, key, value, score_mod=checkerboard)

torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2)
```

### Future Work
- This PR is currently only forward only. However the triton kernel for backwards where score_modifications to not rely on external buffers has been explored here: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/flash/flash_attention.py
- Kernel Improvements; There are has been some larger updates to the fused attention implementation that Triton uses in its tutorials. The implementation of this kernel is based on a prior version and should be updated.
- We may want to unify this API under the top level SDPA API and leave that as a follow up once this is more stable
- Should we error on CPU?
- There are some issues with dynamic shapes
- Capturing of free variables and lifting to inputs to the subgraph is not working correctly today

### Performance
Comparisons generated by this benchmark:

| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     5.412 |              |             |             |             |            |               |                |
| Max     |     8.882 |           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |
| Min     |     3.645 |            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |
| Min     |     0.345 |            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |

For reference

| Configuration                                 | Forward Time (µ seconds) | Backend          | Speedup |
|-----------------------------------------------|--------------------------|------------------|---------|
| Fastest Config in Sweep (`8 16 4096 4096 64 relative_bias torch.bfloat16`) | 3608                   | Templated Attention                | 1.0  |
| Compiled SDPA (No Mask)                       | 9928                   | Math             | 2.75x   |
| Compiled SDPA (With Mask)                     | 11898                    | Math             | 3.29x   |
| Compiled SDPA (With Mask) | 8704                      | Memory Efficient Attention | 2.42x   |
| Compiled SDPA (No Mask) | 2548                     | FlashAttention2 | 0.706x   |

The speedups are measuring compiled templated attention speed versus different calls to torch.nn.functional.sdpa

<details>

<summary> FULL PERFORMANCE SWEEP NUMBERS </summary>

|   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |   eager_time |   compiled_time |   speedup |
|--------------|-------------|-------------|-------------|------------|---------------|----------------|--------------|-----------------|-----------|
|            1 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      331.444 |          67.221 |     4.931 |
|            1 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      335.300 |          64.187 |     5.224 |
|            1 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      352.039 |          63.806 |     5.517 |
|            1 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      371.699 |         711.349 |     0.523 |
|            1 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |      333.488 |          86.455 |     3.857 |
|            1 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |      322.363 |          82.469 |     3.909 |
|            1 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |      349.967 |          82.233 |     4.256 |
|            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |      486.359 |        1412.453 |     0.344 |
|            1 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |     2794.597 |         551.188 |     5.070 |
|            1 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |     3965.150 |         513.101 |     7.728 |
|            1 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |     2408.013 |         504.759 |     4.771 |
|            1 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |     6850.531 |       16733.675 |     0.409 |
|            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      441.939 |         123.576 |     3.576 |
|            8 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      560.379 |         116.710 |     4.801 |
|            8 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      421.172 |         115.825 |     3.636 |
|            8 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      994.492 |        2132.806 |     0.466 |
|            8 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     1436.430 |         309.495 |     4.641 |
|            8 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     1892.216 |         290.186 |     6.521 |
|            8 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     1360.665 |         282.956 |     4.809 |
|            8 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     3525.532 |        8359.702 |     0.422 |
|            8 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    22026.839 |        3864.604 |     5.700 |
|            8 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    31262.746 |        3609.551 |     8.661 |
|            8 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    20219.079 |        3480.402 |     5.809 |
|            8 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |    54654.647 |      116652.357 |     0.469 |
|           16 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      820.606 |         188.683 |     4.349 |
|           16 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |     1058.362 |         179.295 |     5.903 |
|           16 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      784.372 |         175.714 |     4.464 |
|           16 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |     1890.792 |        4212.877 |     0.449 |
|           16 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     2781.830 |         557.017 |     4.994 |
|           16 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     3694.050 |         525.249 |     7.033 |
|           16 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     2634.164 |         507.613 |     5.189 |
|           16 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     6959.917 |       15331.116 |     0.454 |
|           16 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    43889.096 |        7582.018 |     5.789 |
|           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    62784.293 |        7075.846 |     8.873 |
|           16 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    40308.606 |        6829.587 |     5.902 |
|           16 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |   108892.137 |      233090.953 |     0.467 |
</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121845
Approved by: https://github.com/Chillee, https://github.com/zou3519
2024-04-06 01:10:44 +00:00
xinan.lin
9743e3a19c [Inductor Intel GPU backend Upstream] Add Inductor Intel GPU backend. (#121895)
As the design in RFC https://github.com/pytorch/pytorch/issues/114856, this PR implemented Intel GPU Inductor backend by:
- Reuse WrapperCodegen and TritonScheduling for python wrapper and kernel code generation. And implenented device-specific code generation in XPUDeviceOpOverrides
- Reuse fx_pass, lowering, codecache, triton kernel auto-tuning, and compilation.

For the test case, this PR provided test/inductor/test_xpu_basic.py for basic inductor backend functionality testing.
We'll reuse all the existing Inductor test case in the next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121895
Approved by: https://github.com/EikanWang, https://github.com/jansel, https://github.com/desertfire
2024-04-05 09:05:11 +00:00
PyTorch MergeBot
16cb5d48dd Revert "[inductor] Add explicit ops.fma and use it in softmax_backward (#122518)"
This reverts commit 05984e642b.

Reverted https://github.com/pytorch/pytorch/pull/122518 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it starts failing in trunk 05984e642b ([comment](https://github.com/pytorch/pytorch/pull/122518#issuecomment-2038631010))
2024-04-05 02:09:32 +00:00
Peter Bell
05984e642b [inductor] Add explicit ops.fma and use it in softmax_backward (#122518)
This allows us to generate an fma even when fp-fusion is disabled
in the compiler.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122518
Approved by: https://github.com/lezcano, https://github.com/Chillee
ghstack dependencies: #121924
2024-04-04 20:53:14 +00:00
Peter Bell
09c72eaa3f [inductor] Remove identity from ops.scan (#119727)
Currently scan has an `init` argument which must be the identity of the
combine function. This isn't strictly necessary if we are more careful about
keeping track of the first element and avoid combining it with anything.

This does additionally require that there are no active load masks, since we can't
do the `where_cond` any more. However, this shouldn't be possible anyway since
scans are always realized and only fused via the scheduler.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119727
Approved by: https://github.com/lezcano
2024-04-01 22:47:26 +00:00
Edward Z. Yang
3178ba0dc9 Don't use sympy Float functions, use an opaque one with no reasoning (#122823)
Sympy simplifications don't obey floating point semantics, so don't
use Sympy for this.  Keep them as is, only evaluate with the reference
implementations when all arguments are known.

This may end up getting subsumed by some other changes later, but I
wanted to understand if this was easy and it seems to be easy.

This doesn't actually depend on the earlier diffs on the stack and I can detach it.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122823
Approved by: https://github.com/lezcano
2024-03-29 19:13:55 +00:00