mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
2f7cfecd86
294 Commits
| Author | SHA1 | Message | Date | |
|---|---|---|---|---|
|
|
2f7cfecd86 |
Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is: * Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.) * Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers. The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions: * FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing). * ModularIndexing, LShift, RShift now assert they are given integer inputs. * Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver * TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division. * Trunc is split to TruncToFloat and TruncToInt. * Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result. * RoundDecimal updated to consistently only ever return a float * Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing) In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information. We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**: * `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy * `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv` These changes have consequences. First, we need to make some administrative changes: * Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2) * Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py** * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here * Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet * Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions. In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments: * Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now * `_assert_bound_is_rational` is no more, we no longer generate rational bounds * Don't intersect non-int value ranges with the `int_range` * Support more sympy Functions for guard SYMPY_INTERP * Assert the type of value range is consistent with the variable type The new asserts uncovered necessary bug fixes: * **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions * **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions * **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr! * **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1 Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py** Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905 Approved by: https://github.com/xadupre, https://github.com/lezcano |
||
|
|
d5cb5d623a |
Revert "Complete revamp of float/promotion sympy handling (#126905)"
This reverts commit
|
||
|
|
fb696ef3aa |
Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is: * Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.) * Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers. The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions: * FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing). * ModularIndexing, LShift, RShift now assert they are given integer inputs. * Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver * TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division. * Trunc is split to TruncToFloat and TruncToInt. * Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result. * RoundDecimal updated to consistently only ever return a float * Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing) In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information. We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**: * `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy * `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv` These changes have consequences. First, we need to make some administrative changes: * Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2) * Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py** * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here * Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet * Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions. In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments: * Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now * `_assert_bound_is_rational` is no more, we no longer generate rational bounds * Don't intersect non-int value ranges with the `int_range` * Support more sympy Functions for guard SYMPY_INTERP * Assert the type of value range is consistent with the variable type The new asserts uncovered necessary bug fixes: * **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions * **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions * **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr! * **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1 Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py** Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905 Approved by: https://github.com/xadupre, https://github.com/lezcano |
||
|
|
d1fad416a8 |
Revert "Add aten._unsafe_masked_index (#116491)"
This reverts commit
|
||
|
|
f03f8bc901 |
Add aten._unsafe_masked_index (#116491)
To generate masked indexing operations that would generate masked loads in triton code Pull Request resolved: https://github.com/pytorch/pytorch/pull/116491 Approved by: https://github.com/lezcano, https://github.com/peterbell10 |
||
|
|
e24a87ed8d |
[BE][Ez]: Apply PYI059 - Generic always come last (#127685)
Generic baseclass should always be last or unexpected issues can occur, especially in non-stub files (such as with MRO). Applies autofixes from the preview PYI059 rule to fix the issues in the codebase. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127685 Approved by: https://github.com/ezyang |
||
|
|
029b3ec775 |
Revert "[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)"
This reverts commit
|
||
|
|
dae33a4961 |
[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)
As part of #125683, this PR adds the initial bf16/fp16 gemm template support with micro-gemm implemented with fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet which will be added in the next PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126068 Approved by: https://github.com/jansel ghstack dependencies: #124021, #126019 |
||
|
|
8a21532e53 |
Fix constant propagation pass (#114471)
This pass was broken in a number of ways, as we were not generating asserts whenever we took it, even though we need to. While doing so, we found that the analysis we were using for choosing whether to generate asserts or not for dynamic shapes was completely broken. Eliminating indirect indexing in this way allows for a number of optimisations. In particular, we can now fuse against these kernels (indirect indexing disallows fusions). The new strategy is as follows: - We always propagate sympy expressions if we can. - If an expression was an indirect_indexing, we call `check_bounds` - We also call `check_bounds` within `CSEProxy.indirect_indexing` - The checks are issued in the buffer where they would go if the were used in a load - This makes them always be codegen'd before the load and stores - In the case of stores, they will be generated potentially much earlier than the stores themselves, which is fine. We add quite a few asserts to preexisting tests to strengthen them. In particular, we make sure that issuing an assert plays well with all kinds of C++ vectorisation. For now, we rely on the logic within `_maybe_evaluate_static` to prove these bounds. This logic is rather limited though. In the future, we might want to rely on Z3 here to be able to prove bounds in a more general way. Supersedes https://github.com/pytorch/pytorch/pull/113068 Fixes https://github.com/pytorch/pytorch/issues/121251 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114471 Approved by: https://github.com/peterbell10 |
||
|
|
ad7700bfdb |
[inductor] Misc changes (#127307)
Pulling unrelated changes out of the larger halide PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/127307 Approved by: https://github.com/yanboliang |
||
|
|
cef776bcd1 |
[inductor][cpp] GEMM template (infra and fp32) (#124021)
This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC https://github.com/pytorch/pytorch/issues/125683 for more background info. 1. Cpp template infrastructure Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`. The MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates. 2. Initial FP32 gemm template This involves a GEMM template implementation `CppPackedGemmTemplate` that supports GEMM with constant weight (`B`) requiring `N` to be a multiple of register blocking while allows the static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix would be prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`). Then it invokes `CppMicroGemm` which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls implemented with ATen vec abstraction. 3. Correctness and performance The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static shape and dynamic shapes. Since it is an initial implementation, we are still working on further performance improves with follow-up PRs including the optimizations in kernels as well as fusions. The perf gains are only observed from a selective number of models compared to the ATen kernels which are implemented with MKL. The perf gains are more obvious with dynamic shapes since MKL only supports packed gemm for static shapes. Below are details. Static shapes | Benchmark | torchbench | huggingface | timm_models | |------------|-------------|--------------|--------------| | Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x | | Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x | | Single-threaded (baseline) | 1.56x | 1.19x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x | Key models being sped up: drq: 1.14x soft_act: 1.12 cait_m36_384: 1.18x Dynamic shapes | Benchmark | torchbench | huggingface | timm_models | | --- | --- | --- | --- | | Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x | | Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x | | Single-threaded (baseline) | 1.55x | 1.20x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x | Key models being sped up: BERT_pytorch: 1.22x pyhpc_turbulent: 1.13x soft_actor_critic: 1.77x BlenderbotForCausalLM: 1.09x cait_m36_384: 1.17x Differential Revision: [D57585365](https://our.internmc.facebook.com/intern/diff/D57585365) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021 Approved by: https://github.com/jansel |
||
|
|
ec8b254ef4 |
Refactored template codegen to explicitly set current body when generating code (#127144)
The main motivation for this refactor is that today, when generating templates, this is what happens.
```
def_kernel() # registers hook for fully generating function definition
store_output() # registers hook for generating the output store. *also* keeps a number of things generated on `self.body`.
```
Later on, when we codegen the template:
|
||
|
|
4608971f7a |
Revert "[inductor][cpp] GEMM template (infra and fp32) (#124021)"
This reverts commit
|
||
|
|
68fddebf84 |
Revert "[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)"
This reverts commit
|
||
|
|
ba3b05fdf3 |
[1/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort stdlib (#127122)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127122 Approved by: https://github.com/kit1980 |
||
|
|
4aa43d11f3 |
[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)
As part of #125683, this PR adds the initial bf16/fp16 gemm template support with micro-gemm implemented with fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet which will be added in the next PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126068 Approved by: https://github.com/jansel ghstack dependencies: #124021, #126019 |
||
|
|
0d1e228550 |
[inductor][cpp] GEMM template (infra and fp32) (#124021)
This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC https://github.com/pytorch/pytorch/issues/125683 for more background info. 1. Cpp template infrastructure Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`. The MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates. 2. Initial FP32 gemm template This involves a GEMM template implementation `CppPackedGemmTemplate` that supports GEMM with constant weight (`B`) requiring `N` to be a multiple of register blocking while allows the static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix would be prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`). Then it invokes `CppMicroGemm` which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls implemented with ATen vec abstraction. 3. Correctness and performance The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static shape and dynamic shapes. Since it is an initial implementation, we are still working on further performance improves with follow-up PRs including the optimizations in kernels as well as fusions. The perf gains are only observed from a selective number of models compared to the ATen kernels which are implemented with MKL. The perf gains are more obvious with dynamic shapes since MKL only supports packed gemm for static shapes. Below are details. Static shapes | Benchmark | torchbench | huggingface | timm_models | |------------|-------------|--------------|--------------| | Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x | | Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x | | Single-threaded (baseline) | 1.56x | 1.19x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x | Key models being sped up: drq: 1.14x soft_act: 1.12 cait_m36_384: 1.18x Dynamic shapes | Benchmark | torchbench | huggingface | timm_models | | --- | --- | --- | --- | | Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x | | Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x | | Single-threaded (baseline) | 1.55x | 1.20x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x | Key models being sped up: BERT_pytorch: 1.22x pyhpc_turbulent: 1.13x soft_actor_critic: 1.77x BlenderbotForCausalLM: 1.09x cait_m36_384: 1.17x Differential Revision: [D57585365](https://our.internmc.facebook.com/intern/diff/D57585365) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021 Approved by: https://github.com/jansel |
||
|
|
25b8dbc3e4 |
Revert "[inductor][cpp] GEMM template (infra and fp32) (#124021)"
This reverts commit
|
||
|
|
926327e8fc |
Revert "[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)"
This reverts commit
|
||
|
|
31412cb2f2 |
[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)
As part of #125683, this PR adds the initial bf16/fp16 gemm template support with micro-gemm implemented with fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet which will be added in the next PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126068 Approved by: https://github.com/jansel ghstack dependencies: #124021, #126019 |
||
|
|
9da7efa677 |
[inductor][cpp] GEMM template (infra and fp32) (#124021)
This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC https://github.com/pytorch/pytorch/issues/125683 for more background info. 1. Cpp template infrastructure Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`. The MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates. 2. Initial FP32 gemm template This involves a GEMM template implementation `CppPackedGemmTemplate` that supports GEMM with constant weight (`B`) requiring `N` to be a multiple of register blocking while allows the static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix would be prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`). Then it invokes `CppMicroGemm` which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls implemented with ATen vec abstraction. 3. Correctness and performance The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static shape and dynamic shapes. Since it is an initial implementation, we are still working on further performance improves with follow-up PRs including the optimizations in kernels as well as fusions. The perf gains are only observed from a selective number of models compared to the ATen kernels which are implemented with MKL. The perf gains are more obvious with dynamic shapes since MKL only supports packed gemm for static shapes. Below are details. Static shapes | Benchmark | torchbench | huggingface | timm_models | |------------|-------------|--------------|--------------| | Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x | | Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x | | Single-threaded (baseline) | 1.56x | 1.19x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x | Key models being sped up: drq: 1.14x soft_act: 1.12 cait_m36_384: 1.18x Dynamic shapes | Benchmark | torchbench | huggingface | timm_models | | --- | --- | --- | --- | | Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x | | Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x | | Single-threaded (baseline) | 1.55x | 1.20x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x | Key models being sped up: BERT_pytorch: 1.22x pyhpc_turbulent: 1.13x soft_actor_critic: 1.77x BlenderbotForCausalLM: 1.09x cait_m36_384: 1.17x Differential Revision: [D57585365](https://our.internmc.facebook.com/intern/diff/D57585365) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021 Approved by: https://github.com/jansel |
||
|
|
4f14282e35 |
Revert "[inductor][cpp] GEMM template (infra and fp32) (#124021)"
This reverts commit |
||
|
|
205f08140e |
Revert "[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)"
This reverts commit |
||
|
|
57c185b4c7 |
[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)
As part of #125683, this PR adds the initial bf16/fp16 gemm template support with micro-gemm implemented with fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet which will be added in the next PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126068 Approved by: https://github.com/jansel ghstack dependencies: #124021, #126019 |
||
|
|
2ac33a9f66 |
[inductor][cpp] GEMM template (infra and fp32) (#124021)
This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC https://github.com/pytorch/pytorch/issues/125683 for more background info. 1. Cpp template infrastructure Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`. The MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates. 2. Initial FP32 gemm template This involves a GEMM template implementation `CppPackedGemmTemplate` that supports GEMM with constant weight (`B`) requiring `N` to be a multiple of register blocking while allows the static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix would be prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`). Then it invokes `CppMicroGemm` which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls implemented with ATen vec abstraction. 3. Correctness and performance The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static shape and dynamic shapes. Since it is an initial implementation, we are still working on further performance improves with follow-up PRs including the optimizations in kernels as well as fusions. The perf gains are only observed from a selective number of models compared to the ATen kernels which are implemented with MKL. The perf gains are more obvious with dynamic shapes since MKL only supports packed gemm for static shapes. Below are details. Static shapes | Benchmark | torchbench | huggingface | timm_models | |------------|-------------|--------------|--------------| | Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x | | Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x | | Single-threaded (baseline) | 1.56x | 1.19x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x | Key models being sped up: drq: 1.14x soft_act: 1.12 cait_m36_384: 1.18x Dynamic shapes | Benchmark | torchbench | huggingface | timm_models | | --- | --- | --- | --- | | Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x | | Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x | | Single-threaded (baseline) | 1.55x | 1.20x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x | Key models being sped up: BERT_pytorch: 1.22x pyhpc_turbulent: 1.13x soft_actor_critic: 1.77x BlenderbotForCausalLM: 1.09x cait_m36_384: 1.17x Differential Revision: [D57585365](https://our.internmc.facebook.com/intern/diff/D57585365) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021 Approved by: https://github.com/jansel |
||
|
|
e4623de4cf |
typing scheduler.py [2/2]: Apply types (#126656)
Add `# mypy: disallow-untyped-defs` to scheduler.py and then fix the resulting fallout. We probably should eventually add a new node between BaseSchedulerNode and all the non-FusedSchedulerNode types to indicate the split between nodes that have a valid `self.node` and ones that don't. That would cause a lot of the `assert self.node is not None` churn to go away - but was a bigger change because a lot of code makes assumptions about types that aren't reflected in the types themselves. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126656 Approved by: https://github.com/eellison |
||
|
|
bf099a08f0 |
[2/N] Non-Tensor: Scalar Support: Add scalar to the cache for eager-through-torch.compile (#124070)
Add scalar information to the kernel configuration. #### Additional Context Currently, the input parameters are orchestrated by input order in the kernel configuration and loaded/mapped to the kernel at runtime. For example, the cache order of the input parameters of `torch.add(a, b, alpha=2.0)` is `a' first, followed by `b` and then `alpha`. The same order is for cache loading. However, the orchestration mechanism does not support kwargs because the order of kwargs is useless. For example, the `out` of `aten::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)` may be before `approximate`. We will support it with subsequent PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124070 Approved by: https://github.com/jansel, https://github.com/jgong5 |
||
|
|
b98decfc38 |
[halide-backend] Refactor codegen/triton.py into codegen/simd.py (#126415)
This PR is primarily just moving stuff around. It creates a new common baseclass for TritonCodegen and the (upcoming) HalideCodegen. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126415 Approved by: https://github.com/shunting314 |
||
|
|
762ce6f062 |
Add Lowering for FlexAttention Backwards (#125515)
# Summary #### What does this PR do? It enables Inductor to actually generate the fused flex attention kernel for the backwards I did some other things along the way: - Abstract out the 'build_subgraph_buffer' subroutine and make it reusable between flex attention and flex_attention backwards. In total we need too build 3 subgraphs for fwd + bwd. 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes the parts of the forward (more efficiently since we already have the row_max via logsumexp), therefore we need to inline both the fwd graph and the joint graph in the bwds kernel. - The version of the backwards kernel is from a somewhat older version of the triton tutorial implementation. I think that we should update in a follow up to a newer version. Notably the blocks need to be square for this to work as currently implemented. I am sure there are many opportunities for optimization. - I didnt correctly register the decomp table + IndexMode when I landed: https://github.com/pytorch/pytorch/pull/123902, this remedies that. - The rel_bias helper func was reversed in terms of causality. I updated and then add a test specific for "future causal" attention. - This PRs but the main point that I think still needs to be worked out is the store_output call. I have it hacked up to be 'fake' but I dont think we want to land that and likely want to just have a mutated 'dq' and a stored_output 'dk' - I also needed to update the `TritonTemplateKernel` to actually accept multiple subgraphs (modifications) - I updated the benchmark to also profile bwds performance ### Benchmark Numbers: _The current implementation is not parallelizing over ctx length in the bwd_ FWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.991 | | | | | Max | 1.182 | (16, 16, 4096, 64) | noop | torch.bfloat16 | | Min | 0.796 | (2, 16, 512, 256) | head_bias | torch.bfloat16 | BWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.291 | | | | | Max | 0.652 | (8, 16, 512, 64) | head_bias | torch.bfloat16 | | Min | 0.073 | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | <details> <summary>Full Data</summary> | shape | score_mod | dtype | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------------|---------------|----------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | (2, 16, 512, 64) | noop | torch.bfloat16 | 19.936 | 19.092 | 57.851 | 193.564 | 1.044 | 0.299 | | (2, 16, 512, 64) | causal_mask | torch.bfloat16 | 19.955 | 19.497 | 57.662 | 206.278 | 1.024 | 0.280 | | (2, 16, 512, 64) | relative_bias | torch.bfloat16 | 19.455 | 21.297 | 57.674 | 195.219 | 0.913 | 0.295 | | (2, 16, 512, 64) | head_bias | torch.bfloat16 | 19.958 | 21.289 | 57.674 | 193.859 | 0.938 | 0.298 | | (2, 16, 512, 128) | noop | torch.bfloat16 | 28.157 | 28.615 | 82.831 | 454.211 | 0.984 | 0.182 | | (2, 16, 512, 128) | causal_mask | torch.bfloat16 | 28.154 | 28.444 | 83.091 | 432.083 | 0.990 | 0.192 | | (2, 16, 512, 128) | relative_bias | torch.bfloat16 | 28.722 | 27.897 | 83.175 | 446.789 | 1.030 | 0.186 | | (2, 16, 512, 128) | head_bias | torch.bfloat16 | 28.299 | 27.673 | 83.052 | 459.179 | 1.023 | 0.181 | | (2, 16, 512, 256) | noop | torch.bfloat16 | 41.167 | 50.504 | 175.019 | 1083.545 | 0.815 | 0.162 | | (2, 16, 512, 256) | causal_mask | torch.bfloat16 | 41.656 | 51.933 | 175.078 | 1171.176 | 0.802 | 0.149 | | (2, 16, 512, 256) | relative_bias | torch.bfloat16 | 41.697 | 50.722 | 175.159 | 1097.312 | 0.822 | 0.160 | | (2, 16, 512, 256) | head_bias | torch.bfloat16 | 41.690 | 52.387 | 175.184 | 1097.336 | 0.796 | 0.160 | | (2, 16, 1024, 64) | noop | torch.bfloat16 | 39.232 | 37.454 | 127.847 | 612.430 | 1.047 | 0.209 | | (2, 16, 1024, 64) | causal_mask | torch.bfloat16 | 39.930 | 39.599 | 127.755 | 665.359 | 1.008 | 0.192 | | (2, 16, 1024, 64) | relative_bias | torch.bfloat16 | 39.417 | 41.304 | 127.902 | 614.990 | 0.954 | 0.208 | | (2, 16, 1024, 64) | head_bias | torch.bfloat16 | 39.965 | 42.034 | 127.953 | 613.273 | 0.951 | 0.209 | | (2, 16, 1024, 128) | noop | torch.bfloat16 | 63.964 | 71.024 | 226.510 | 1637.669 | 0.901 | 0.138 | | (2, 16, 1024, 128) | causal_mask | torch.bfloat16 | 63.843 | 72.451 | 226.750 | 1558.949 | 0.881 | 0.145 | | (2, 16, 1024, 128) | relative_bias | torch.bfloat16 | 64.301 | 70.487 | 226.651 | 1610.063 | 0.912 | 0.141 | | (2, 16, 1024, 128) | head_bias | torch.bfloat16 | 64.033 | 71.394 | 226.676 | 1668.511 | 0.897 | 0.136 | | (2, 16, 1024, 256) | noop | torch.bfloat16 | 129.348 | 141.390 | 507.337 | 4405.175 | 0.915 | 0.115 | | (2, 16, 1024, 256) | causal_mask | torch.bfloat16 | 129.538 | 145.680 | 507.178 | 4768.874 | 0.889 | 0.106 | | (2, 16, 1024, 256) | relative_bias | torch.bfloat16 | 129.438 | 142.782 | 507.004 | 4401.002 | 0.907 | 0.115 | | (2, 16, 1024, 256) | head_bias | torch.bfloat16 | 129.058 | 146.242 | 507.547 | 4434.251 | 0.883 | 0.114 | | (2, 16, 4096, 64) | noop | torch.bfloat16 | 481.606 | 409.120 | 1440.890 | 14147.269 | 1.177 | 0.102 | | (2, 16, 4096, 64) | causal_mask | torch.bfloat16 | 480.227 | 438.847 | 1434.419 | 14973.386 | 1.094 | 0.096 | | (2, 16, 4096, 64) | relative_bias | torch.bfloat16 | 480.831 | 458.104 | 1432.935 | 14193.253 | 1.050 | 0.101 | | (2, 16, 4096, 64) | head_bias | torch.bfloat16 | 480.749 | 452.497 | 1437.040 | 14084.869 | 1.062 | 0.102 | | (2, 16, 4096, 128) | noop | torch.bfloat16 | 872.534 | 848.275 | 2600.895 | 35156.849 | 1.029 | 0.074 | | (2, 16, 4096, 128) | causal_mask | torch.bfloat16 | 872.647 | 868.279 | 2587.581 | 31919.531 | 1.005 | 0.081 | | (2, 16, 4096, 128) | relative_bias | torch.bfloat16 | 871.484 | 827.644 | 2593.989 | 34805.634 | 1.053 | 0.075 | | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | 871.422 | 856.437 | 2602.482 | 35708.591 | 1.017 | 0.073 | | (2, 16, 4096, 256) | noop | torch.bfloat16 | 1904.497 | 1758.183 | 6122.416 | 66754.593 | 1.083 | 0.092 | | (2, 16, 4096, 256) | causal_mask | torch.bfloat16 | 1911.174 | 1762.821 | 6113.207 | 72759.392 | 1.084 | 0.084 | | (2, 16, 4096, 256) | relative_bias | torch.bfloat16 | 1911.254 | 1727.108 | 6123.530 | 66577.988 | 1.107 | 0.092 | | (2, 16, 4096, 256) | head_bias | torch.bfloat16 | 1916.977 | 1801.804 | 6118.158 | 67359.680 | 1.064 | 0.091 | | (8, 16, 512, 64) | noop | torch.bfloat16 | 44.984 | 43.974 | 170.276 | 262.259 | 1.023 | 0.649 | | (8, 16, 512, 64) | causal_mask | torch.bfloat16 | 45.001 | 46.265 | 170.509 | 274.893 | 0.973 | 0.620 | | (8, 16, 512, 64) | relative_bias | torch.bfloat16 | 45.466 | 48.211 | 170.606 | 262.759 | 0.943 | 0.649 | | (8, 16, 512, 64) | head_bias | torch.bfloat16 | 45.481 | 48.435 | 170.267 | 261.265 | 0.939 | 0.652 | | (8, 16, 512, 128) | noop | torch.bfloat16 | 72.565 | 74.736 | 313.220 | 773.126 | 0.971 | 0.405 | | (8, 16, 512, 128) | causal_mask | torch.bfloat16 | 72.015 | 75.755 | 313.311 | 775.513 | 0.951 | 0.404 | | (8, 16, 512, 128) | relative_bias | torch.bfloat16 | 72.105 | 74.189 | 313.806 | 769.238 | 0.972 | 0.408 | | (8, 16, 512, 128) | head_bias | torch.bfloat16 | 72.005 | 74.364 | 313.509 | 775.237 | 0.968 | 0.404 | | (8, 16, 512, 256) | noop | torch.bfloat16 | 138.656 | 165.453 | 663.707 | 2672.067 | 0.838 | 0.248 | | (8, 16, 512, 256) | causal_mask | torch.bfloat16 | 139.096 | 172.613 | 663.593 | 2926.538 | 0.806 | 0.227 | | (8, 16, 512, 256) | relative_bias | torch.bfloat16 | 139.500 | 168.417 | 663.938 | 2658.629 | 0.828 | 0.250 | | (8, 16, 512, 256) | head_bias | torch.bfloat16 | 139.776 | 173.549 | 662.920 | 2667.266 | 0.805 | 0.249 | | (8, 16, 1024, 64) | noop | torch.bfloat16 | 134.883 | 125.004 | 484.706 | 1195.254 | 1.079 | 0.406 | | (8, 16, 1024, 64) | causal_mask | torch.bfloat16 | 134.297 | 132.875 | 485.420 | 1234.953 | 1.011 | 0.393 | | (8, 16, 1024, 64) | relative_bias | torch.bfloat16 | 134.839 | 139.231 | 485.470 | 1198.556 | 0.968 | 0.405 | | (8, 16, 1024, 64) | head_bias | torch.bfloat16 | 133.822 | 136.449 | 485.608 | 1189.198 | 0.981 | 0.408 | | (8, 16, 1024, 128) | noop | torch.bfloat16 | 235.470 | 234.765 | 886.094 | 2662.944 | 1.003 | 0.333 | | (8, 16, 1024, 128) | causal_mask | torch.bfloat16 | 236.305 | 241.382 | 886.293 | 2646.984 | 0.979 | 0.335 | | (8, 16, 1024, 128) | relative_bias | torch.bfloat16 | 236.414 | 233.980 | 885.250 | 2642.178 | 1.010 | 0.335 | | (8, 16, 1024, 128) | head_bias | torch.bfloat16 | 237.176 | 239.040 | 885.754 | 2665.242 | 0.992 | 0.332 | | (8, 16, 1024, 256) | noop | torch.bfloat16 | 504.445 | 517.855 | 1978.956 | 9592.906 | 0.974 | 0.206 | | (8, 16, 1024, 256) | causal_mask | torch.bfloat16 | 502.428 | 536.002 | 1978.611 | 10607.342 | 0.937 | 0.187 | | (8, 16, 1024, 256) | relative_bias | torch.bfloat16 | 503.396 | 523.960 | 1977.993 | 9539.284 | 0.961 | 0.207 | | (8, 16, 1024, 256) | head_bias | torch.bfloat16 | 503.818 | 536.014 | 1980.131 | 9576.262 | 0.940 | 0.207 | | (8, 16, 4096, 64) | noop | torch.bfloat16 | 1970.139 | 1674.930 | 5750.940 | 16724.134 | 1.176 | 0.344 | | (8, 16, 4096, 64) | causal_mask | torch.bfloat16 | 1959.036 | 1775.056 | 5780.512 | 17390.350 | 1.104 | 0.332 | | (8, 16, 4096, 64) | relative_bias | torch.bfloat16 | 1947.198 | 1773.869 | 5780.643 | 16779.699 | 1.098 | 0.345 | | (8, 16, 4096, 64) | head_bias | torch.bfloat16 | 1963.935 | 1829.502 | 5780.018 | 16703.259 | 1.073 | 0.346 | | (8, 16, 4096, 128) | noop | torch.bfloat16 | 3582.711 | 3362.623 | 10436.069 | 36415.565 | 1.065 | 0.287 | | (8, 16, 4096, 128) | causal_mask | torch.bfloat16 | 3581.504 | 3499.472 | 10346.869 | 36164.959 | 1.023 | 0.286 | | (8, 16, 4096, 128) | relative_bias | torch.bfloat16 | 3589.779 | 3337.849 | 10529.621 | 36261.696 | 1.075 | 0.290 | | (8, 16, 4096, 128) | head_bias | torch.bfloat16 | 3602.265 | 3436.444 | 10458.660 | 36507.790 | 1.048 | 0.286 | | (8, 16, 4096, 256) | noop | torch.bfloat16 | 7695.923 | 7126.275 | 24643.009 | 140949.081 | 1.080 | 0.175 | | (8, 16, 4096, 256) | causal_mask | torch.bfloat16 | 7679.939 | 7186.252 | 24538.105 | 157156.067 | 1.069 | 0.156 | | (8, 16, 4096, 256) | relative_bias | torch.bfloat16 | 7681.374 | 6994.832 | 24549.713 | 140077.179 | 1.098 | 0.175 | | (8, 16, 4096, 256) | head_bias | torch.bfloat16 | 7679.822 | 7212.278 | 24627.823 | 140675.003 | 1.065 | 0.175 | | (16, 16, 512, 64) | noop | torch.bfloat16 | 80.126 | 78.291 | 333.719 | 541.165 | 1.023 | 0.617 | | (16, 16, 512, 64) | causal_mask | torch.bfloat16 | 80.065 | 81.696 | 333.779 | 551.113 | 0.980 | 0.606 | | (16, 16, 512, 64) | relative_bias | torch.bfloat16 | 80.138 | 86.715 | 333.364 | 542.118 | 0.924 | 0.615 | | (16, 16, 512, 64) | head_bias | torch.bfloat16 | 80.415 | 85.204 | 333.294 | 536.840 | 0.944 | 0.621 | | (16, 16, 512, 128) | noop | torch.bfloat16 | 134.964 | 138.025 | 607.093 | 1333.102 | 0.978 | 0.455 | | (16, 16, 512, 128) | causal_mask | torch.bfloat16 | 134.192 | 141.523 | 606.269 | 1424.318 | 0.948 | 0.426 | | (16, 16, 512, 128) | relative_bias | torch.bfloat16 | 135.711 | 138.639 | 606.283 | 1327.974 | 0.979 | 0.457 | | (16, 16, 512, 128) | head_bias | torch.bfloat16 | 135.552 | 140.555 | 607.107 | 1347.370 | 0.964 | 0.451 | | (16, 16, 512, 256) | noop | torch.bfloat16 | 275.113 | 315.144 | 1301.583 | 5268.153 | 0.873 | 0.247 | | (16, 16, 512, 256) | causal_mask | torch.bfloat16 | 274.867 | 328.106 | 1302.513 | 5770.594 | 0.838 | 0.226 | | (16, 16, 512, 256) | relative_bias | torch.bfloat16 | 276.052 | 321.770 | 1302.904 | 5241.920 | 0.858 | 0.249 | | (16, 16, 512, 256) | head_bias | torch.bfloat16 | 271.409 | 328.839 | 1302.142 | 5266.037 | 0.825 | 0.247 | | (16, 16, 1024, 64) | noop | torch.bfloat16 | 260.489 | 237.463 | 955.884 | 1817.558 | 1.097 | 0.526 | | (16, 16, 1024, 64) | causal_mask | torch.bfloat16 | 262.378 | 254.350 | 955.280 | 1843.807 | 1.032 | 0.518 | | (16, 16, 1024, 64) | relative_bias | torch.bfloat16 | 261.338 | 268.253 | 956.038 | 1820.036 | 0.974 | 0.525 | | (16, 16, 1024, 64) | head_bias | torch.bfloat16 | 262.153 | 264.156 | 956.023 | 1810.076 | 0.992 | 0.528 | | (16, 16, 1024, 128) | noop | torch.bfloat16 | 476.475 | 461.413 | 1760.578 | 4306.521 | 1.033 | 0.409 | | (16, 16, 1024, 128) | causal_mask | torch.bfloat16 | 473.794 | 479.178 | 1761.277 | 4619.439 | 0.989 | 0.381 | | (16, 16, 1024, 128) | relative_bias | torch.bfloat16 | 473.839 | 463.282 | 1758.692 | 4290.562 | 1.023 | 0.410 | | (16, 16, 1024, 128) | head_bias | torch.bfloat16 | 472.979 | 472.896 | 1763.086 | 4367.931 | 1.000 | 0.404 | | (16, 16, 1024, 256) | noop | torch.bfloat16 | 1014.184 | 1026.764 | 3922.997 | 19104.147 | 0.988 | 0.205 | | (16, 16, 1024, 256) | causal_mask | torch.bfloat16 | 1013.217 | 1039.046 | 3928.382 | 21086.281 | 0.975 | 0.186 | | (16, 16, 1024, 256) | relative_bias | torch.bfloat16 | 1008.519 | 1015.278 | 3922.133 | 18980.652 | 0.993 | 0.207 | | (16, 16, 1024, 256) | head_bias | torch.bfloat16 | 1011.360 | 1047.542 | 3931.245 | 19069.172 | 0.965 | 0.206 | | (16, 16, 4096, 64) | noop | torch.bfloat16 | 3929.850 | 3325.667 | 11411.704 | 23344.280 | 1.182 | 0.489 | | (16, 16, 4096, 64) | causal_mask | torch.bfloat16 | 3885.262 | 3581.544 | 11390.515 | 23725.639 | 1.085 | 0.480 | | (16, 16, 4096, 64) | relative_bias | torch.bfloat16 | 3865.737 | 3537.308 | 11489.901 | 23406.330 | 1.093 | 0.491 | | (16, 16, 4096, 64) | head_bias | torch.bfloat16 | 3880.530 | 3665.249 | 11484.411 | 23299.496 | 1.059 | 0.493 | | (16, 16, 4096, 128) | noop | torch.bfloat16 | 7030.306 | 6745.715 | 20621.264 | 57464.096 | 1.042 | 0.359 | | (16, 16, 4096, 128) | causal_mask | torch.bfloat16 | 7095.414 | 7034.385 | 20410.656 | 61660.511 | 1.009 | 0.331 | | (16, 16, 4096, 128) | relative_bias | torch.bfloat16 | 7084.779 | 6686.497 | 20315.161 | 57243.969 | 1.060 | 0.355 | | (16, 16, 4096, 128) | head_bias | torch.bfloat16 | 7075.367 | 6863.305 | 20494.385 | 58481.953 | 1.031 | 0.350 | | (16, 16, 4096, 256) | noop | torch.bfloat16 | 15612.741 | 14297.482 | 55306.847 | 281161.865 | 1.092 | 0.197 | | (16, 16, 4096, 256) | causal_mask | torch.bfloat16 | 15326.592 | 14263.878 | 55227.806 | 313063.232 | 1.075 | 0.176 | | (16, 16, 4096, 256) | relative_bias | torch.bfloat16 | 15297.963 | 14007.379 | 54558.029 | 279529.175 | 1.092 | 0.195 | | (16, 16, 4096, 256) | head_bias | torch.bfloat16 | 15216.160 | 14276.027 | 55081.581 | 280996.826 | 1.066 | 0.196 | </details> Pull Request resolved: https://github.com/pytorch/pytorch/pull/125515 Approved by: https://github.com/Chillee |
||
|
|
337830f657 |
Revert "[inductor][cpp] GEMM template (infra and fp32) (#124021)"
This reverts commit
|
||
|
|
59ca0d8c14 |
Revert "[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)"
This reverts commit
|
||
|
|
cef7756c9c |
[inductor] Clear cache on ctx manager exit (#126146)
FIXES https://github.com/pytorch/pytorch/issues/126128. Right now, we only clear the cache on ctx manager enter. So state is bad unless we call fresh_inductor_cache again, usually fine in tests. Cue compiled autograd tests when going from TestCompiledAutograd -> TestAutogradWithCompiledAutograd. TestCompiledAutograd uses the ctx manager, but TestAutogradWithCompiledAutograd don't Pull Request resolved: https://github.com/pytorch/pytorch/pull/126146 Approved by: https://github.com/jgong5, https://github.com/oulgen ghstack dependencies: #126144 |
||
|
|
82c66bc41a |
Make 'pytest test/inductor/test_memory_planning.py' work (#126397)
There's still another naughty direct test_* import, I'm out of patience right now though. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/126397 Approved by: https://github.com/peterbell10, https://github.com/int3 |
||
|
|
0716f75cfb |
Revert "Add Lowering for FlexAttention Backwards (#125515)"
This reverts commit |
||
|
|
08aa704d0c |
[1/N] Non-Tensor: Scalar Support: Enable aot compile to support aten operations with scalar input like alpha (#124177)
Some operations have a scalar input parameter, like `torch.add(a, b, alpha=2.0)`. Currently, the aot compile does not support such a case because it requires the signature of the captured graph to align with the operation's signature. This means that some inputs in the captured graph may be scalar(float, int, bool, etc.). It breaks the assumption of `compile_fx_aot` as it assumes all the example inputs are tensor -
|
||
|
|
95b9e981c3 |
Add Lowering for FlexAttention Backwards (#125515)
# Summary #### What does this PR do? It enables Inductor to actually generate the fused flex attention kernel for the backwards I did some other things along the way: - Abstract out the 'build_subgraph_buffer' subroutine and make it reusable between flex attention and flex_attention backwards. In total we need too build 3 subgraphs for fwd + bwd. 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes the parts of the forward (more efficiently since we already have the row_max via logsumexp), therefore we need to inline both the fwd graph and the joint graph in the bwds kernel. - The version of the backwards kernel is from a somewhat older version of the triton tutorial implementation. I think that we should update in a follow up to a newer version. Notably the blocks need to be square for this to work as currently implemented. I am sure there are many opportunities for optimization. - I didnt correctly register the decomp table + IndexMode when I landed: https://github.com/pytorch/pytorch/pull/123902, this remedies that. - The rel_bias helper func was reversed in terms of causality. I updated and then add a test specific for "future causal" attention. - This PRs but the main point that I think still needs to be worked out is the store_output call. I have it hacked up to be 'fake' but I dont think we want to land that and likely want to just have a mutated 'dq' and a stored_output 'dk' - I also needed to update the `TritonTemplateKernel` to actually accept multiple subgraphs (modifications) - I updated the benchmark to also profile bwds performance ### Benchmark Numbers: _The current implementation is not parallelizing over ctx length in the bwd_ FWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.991 | | | | | Max | 1.182 | (16, 16, 4096, 64) | noop | torch.bfloat16 | | Min | 0.796 | (2, 16, 512, 256) | head_bias | torch.bfloat16 | BWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.291 | | | | | Max | 0.652 | (8, 16, 512, 64) | head_bias | torch.bfloat16 | | Min | 0.073 | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | <details> <summary>Full Data</summary> | shape | score_mod | dtype | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------------|---------------|----------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | (2, 16, 512, 64) | noop | torch.bfloat16 | 19.936 | 19.092 | 57.851 | 193.564 | 1.044 | 0.299 | | (2, 16, 512, 64) | causal_mask | torch.bfloat16 | 19.955 | 19.497 | 57.662 | 206.278 | 1.024 | 0.280 | | (2, 16, 512, 64) | relative_bias | torch.bfloat16 | 19.455 | 21.297 | 57.674 | 195.219 | 0.913 | 0.295 | | (2, 16, 512, 64) | head_bias | torch.bfloat16 | 19.958 | 21.289 | 57.674 | 193.859 | 0.938 | 0.298 | | (2, 16, 512, 128) | noop | torch.bfloat16 | 28.157 | 28.615 | 82.831 | 454.211 | 0.984 | 0.182 | | (2, 16, 512, 128) | causal_mask | torch.bfloat16 | 28.154 | 28.444 | 83.091 | 432.083 | 0.990 | 0.192 | | (2, 16, 512, 128) | relative_bias | torch.bfloat16 | 28.722 | 27.897 | 83.175 | 446.789 | 1.030 | 0.186 | | (2, 16, 512, 128) | head_bias | torch.bfloat16 | 28.299 | 27.673 | 83.052 | 459.179 | 1.023 | 0.181 | | (2, 16, 512, 256) | noop | torch.bfloat16 | 41.167 | 50.504 | 175.019 | 1083.545 | 0.815 | 0.162 | | (2, 16, 512, 256) | causal_mask | torch.bfloat16 | 41.656 | 51.933 | 175.078 | 1171.176 | 0.802 | 0.149 | | (2, 16, 512, 256) | relative_bias | torch.bfloat16 | 41.697 | 50.722 | 175.159 | 1097.312 | 0.822 | 0.160 | | (2, 16, 512, 256) | head_bias | torch.bfloat16 | 41.690 | 52.387 | 175.184 | 1097.336 | 0.796 | 0.160 | | (2, 16, 1024, 64) | noop | torch.bfloat16 | 39.232 | 37.454 | 127.847 | 612.430 | 1.047 | 0.209 | | (2, 16, 1024, 64) | causal_mask | torch.bfloat16 | 39.930 | 39.599 | 127.755 | 665.359 | 1.008 | 0.192 | | (2, 16, 1024, 64) | relative_bias | torch.bfloat16 | 39.417 | 41.304 | 127.902 | 614.990 | 0.954 | 0.208 | | (2, 16, 1024, 64) | head_bias | torch.bfloat16 | 39.965 | 42.034 | 127.953 | 613.273 | 0.951 | 0.209 | | (2, 16, 1024, 128) | noop | torch.bfloat16 | 63.964 | 71.024 | 226.510 | 1637.669 | 0.901 | 0.138 | | (2, 16, 1024, 128) | causal_mask | torch.bfloat16 | 63.843 | 72.451 | 226.750 | 1558.949 | 0.881 | 0.145 | | (2, 16, 1024, 128) | relative_bias | torch.bfloat16 | 64.301 | 70.487 | 226.651 | 1610.063 | 0.912 | 0.141 | | (2, 16, 1024, 128) | head_bias | torch.bfloat16 | 64.033 | 71.394 | 226.676 | 1668.511 | 0.897 | 0.136 | | (2, 16, 1024, 256) | noop | torch.bfloat16 | 129.348 | 141.390 | 507.337 | 4405.175 | 0.915 | 0.115 | | (2, 16, 1024, 256) | causal_mask | torch.bfloat16 | 129.538 | 145.680 | 507.178 | 4768.874 | 0.889 | 0.106 | | (2, 16, 1024, 256) | relative_bias | torch.bfloat16 | 129.438 | 142.782 | 507.004 | 4401.002 | 0.907 | 0.115 | | (2, 16, 1024, 256) | head_bias | torch.bfloat16 | 129.058 | 146.242 | 507.547 | 4434.251 | 0.883 | 0.114 | | (2, 16, 4096, 64) | noop | torch.bfloat16 | 481.606 | 409.120 | 1440.890 | 14147.269 | 1.177 | 0.102 | | (2, 16, 4096, 64) | causal_mask | torch.bfloat16 | 480.227 | 438.847 | 1434.419 | 14973.386 | 1.094 | 0.096 | | (2, 16, 4096, 64) | relative_bias | torch.bfloat16 | 480.831 | 458.104 | 1432.935 | 14193.253 | 1.050 | 0.101 | | (2, 16, 4096, 64) | head_bias | torch.bfloat16 | 480.749 | 452.497 | 1437.040 | 14084.869 | 1.062 | 0.102 | | (2, 16, 4096, 128) | noop | torch.bfloat16 | 872.534 | 848.275 | 2600.895 | 35156.849 | 1.029 | 0.074 | | (2, 16, 4096, 128) | causal_mask | torch.bfloat16 | 872.647 | 868.279 | 2587.581 | 31919.531 | 1.005 | 0.081 | | (2, 16, 4096, 128) | relative_bias | torch.bfloat16 | 871.484 | 827.644 | 2593.989 | 34805.634 | 1.053 | 0.075 | | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | 871.422 | 856.437 | 2602.482 | 35708.591 | 1.017 | 0.073 | | (2, 16, 4096, 256) | noop | torch.bfloat16 | 1904.497 | 1758.183 | 6122.416 | 66754.593 | 1.083 | 0.092 | | (2, 16, 4096, 256) | causal_mask | torch.bfloat16 | 1911.174 | 1762.821 | 6113.207 | 72759.392 | 1.084 | 0.084 | | (2, 16, 4096, 256) | relative_bias | torch.bfloat16 | 1911.254 | 1727.108 | 6123.530 | 66577.988 | 1.107 | 0.092 | | (2, 16, 4096, 256) | head_bias | torch.bfloat16 | 1916.977 | 1801.804 | 6118.158 | 67359.680 | 1.064 | 0.091 | | (8, 16, 512, 64) | noop | torch.bfloat16 | 44.984 | 43.974 | 170.276 | 262.259 | 1.023 | 0.649 | | (8, 16, 512, 64) | causal_mask | torch.bfloat16 | 45.001 | 46.265 | 170.509 | 274.893 | 0.973 | 0.620 | | (8, 16, 512, 64) | relative_bias | torch.bfloat16 | 45.466 | 48.211 | 170.606 | 262.759 | 0.943 | 0.649 | | (8, 16, 512, 64) | head_bias | torch.bfloat16 | 45.481 | 48.435 | 170.267 | 261.265 | 0.939 | 0.652 | | (8, 16, 512, 128) | noop | torch.bfloat16 | 72.565 | 74.736 | 313.220 | 773.126 | 0.971 | 0.405 | | (8, 16, 512, 128) | causal_mask | torch.bfloat16 | 72.015 | 75.755 | 313.311 | 775.513 | 0.951 | 0.404 | | (8, 16, 512, 128) | relative_bias | torch.bfloat16 | 72.105 | 74.189 | 313.806 | 769.238 | 0.972 | 0.408 | | (8, 16, 512, 128) | head_bias | torch.bfloat16 | 72.005 | 74.364 | 313.509 | 775.237 | 0.968 | 0.404 | | (8, 16, 512, 256) | noop | torch.bfloat16 | 138.656 | 165.453 | 663.707 | 2672.067 | 0.838 | 0.248 | | (8, 16, 512, 256) | causal_mask | torch.bfloat16 | 139.096 | 172.613 | 663.593 | 2926.538 | 0.806 | 0.227 | | (8, 16, 512, 256) | relative_bias | torch.bfloat16 | 139.500 | 168.417 | 663.938 | 2658.629 | 0.828 | 0.250 | | (8, 16, 512, 256) | head_bias | torch.bfloat16 | 139.776 | 173.549 | 662.920 | 2667.266 | 0.805 | 0.249 | | (8, 16, 1024, 64) | noop | torch.bfloat16 | 134.883 | 125.004 | 484.706 | 1195.254 | 1.079 | 0.406 | | (8, 16, 1024, 64) | causal_mask | torch.bfloat16 | 134.297 | 132.875 | 485.420 | 1234.953 | 1.011 | 0.393 | | (8, 16, 1024, 64) | relative_bias | torch.bfloat16 | 134.839 | 139.231 | 485.470 | 1198.556 | 0.968 | 0.405 | | (8, 16, 1024, 64) | head_bias | torch.bfloat16 | 133.822 | 136.449 | 485.608 | 1189.198 | 0.981 | 0.408 | | (8, 16, 1024, 128) | noop | torch.bfloat16 | 235.470 | 234.765 | 886.094 | 2662.944 | 1.003 | 0.333 | | (8, 16, 1024, 128) | causal_mask | torch.bfloat16 | 236.305 | 241.382 | 886.293 | 2646.984 | 0.979 | 0.335 | | (8, 16, 1024, 128) | relative_bias | torch.bfloat16 | 236.414 | 233.980 | 885.250 | 2642.178 | 1.010 | 0.335 | | (8, 16, 1024, 128) | head_bias | torch.bfloat16 | 237.176 | 239.040 | 885.754 | 2665.242 | 0.992 | 0.332 | | (8, 16, 1024, 256) | noop | torch.bfloat16 | 504.445 | 517.855 | 1978.956 | 9592.906 | 0.974 | 0.206 | | (8, 16, 1024, 256) | causal_mask | torch.bfloat16 | 502.428 | 536.002 | 1978.611 | 10607.342 | 0.937 | 0.187 | | (8, 16, 1024, 256) | relative_bias | torch.bfloat16 | 503.396 | 523.960 | 1977.993 | 9539.284 | 0.961 | 0.207 | | (8, 16, 1024, 256) | head_bias | torch.bfloat16 | 503.818 | 536.014 | 1980.131 | 9576.262 | 0.940 | 0.207 | | (8, 16, 4096, 64) | noop | torch.bfloat16 | 1970.139 | 1674.930 | 5750.940 | 16724.134 | 1.176 | 0.344 | | (8, 16, 4096, 64) | causal_mask | torch.bfloat16 | 1959.036 | 1775.056 | 5780.512 | 17390.350 | 1.104 | 0.332 | | (8, 16, 4096, 64) | relative_bias | torch.bfloat16 | 1947.198 | 1773.869 | 5780.643 | 16779.699 | 1.098 | 0.345 | | (8, 16, 4096, 64) | head_bias | torch.bfloat16 | 1963.935 | 1829.502 | 5780.018 | 16703.259 | 1.073 | 0.346 | | (8, 16, 4096, 128) | noop | torch.bfloat16 | 3582.711 | 3362.623 | 10436.069 | 36415.565 | 1.065 | 0.287 | | (8, 16, 4096, 128) | causal_mask | torch.bfloat16 | 3581.504 | 3499.472 | 10346.869 | 36164.959 | 1.023 | 0.286 | | (8, 16, 4096, 128) | relative_bias | torch.bfloat16 | 3589.779 | 3337.849 | 10529.621 | 36261.696 | 1.075 | 0.290 | | (8, 16, 4096, 128) | head_bias | torch.bfloat16 | 3602.265 | 3436.444 | 10458.660 | 36507.790 | 1.048 | 0.286 | | (8, 16, 4096, 256) | noop | torch.bfloat16 | 7695.923 | 7126.275 | 24643.009 | 140949.081 | 1.080 | 0.175 | | (8, 16, 4096, 256) | causal_mask | torch.bfloat16 | 7679.939 | 7186.252 | 24538.105 | 157156.067 | 1.069 | 0.156 | | (8, 16, 4096, 256) | relative_bias | torch.bfloat16 | 7681.374 | 6994.832 | 24549.713 | 140077.179 | 1.098 | 0.175 | | (8, 16, 4096, 256) | head_bias | torch.bfloat16 | 7679.822 | 7212.278 | 24627.823 | 140675.003 | 1.065 | 0.175 | | (16, 16, 512, 64) | noop | torch.bfloat16 | 80.126 | 78.291 | 333.719 | 541.165 | 1.023 | 0.617 | | (16, 16, 512, 64) | causal_mask | torch.bfloat16 | 80.065 | 81.696 | 333.779 | 551.113 | 0.980 | 0.606 | | (16, 16, 512, 64) | relative_bias | torch.bfloat16 | 80.138 | 86.715 | 333.364 | 542.118 | 0.924 | 0.615 | | (16, 16, 512, 64) | head_bias | torch.bfloat16 | 80.415 | 85.204 | 333.294 | 536.840 | 0.944 | 0.621 | | (16, 16, 512, 128) | noop | torch.bfloat16 | 134.964 | 138.025 | 607.093 | 1333.102 | 0.978 | 0.455 | | (16, 16, 512, 128) | causal_mask | torch.bfloat16 | 134.192 | 141.523 | 606.269 | 1424.318 | 0.948 | 0.426 | | (16, 16, 512, 128) | relative_bias | torch.bfloat16 | 135.711 | 138.639 | 606.283 | 1327.974 | 0.979 | 0.457 | | (16, 16, 512, 128) | head_bias | torch.bfloat16 | 135.552 | 140.555 | 607.107 | 1347.370 | 0.964 | 0.451 | | (16, 16, 512, 256) | noop | torch.bfloat16 | 275.113 | 315.144 | 1301.583 | 5268.153 | 0.873 | 0.247 | | (16, 16, 512, 256) | causal_mask | torch.bfloat16 | 274.867 | 328.106 | 1302.513 | 5770.594 | 0.838 | 0.226 | | (16, 16, 512, 256) | relative_bias | torch.bfloat16 | 276.052 | 321.770 | 1302.904 | 5241.920 | 0.858 | 0.249 | | (16, 16, 512, 256) | head_bias | torch.bfloat16 | 271.409 | 328.839 | 1302.142 | 5266.037 | 0.825 | 0.247 | | (16, 16, 1024, 64) | noop | torch.bfloat16 | 260.489 | 237.463 | 955.884 | 1817.558 | 1.097 | 0.526 | | (16, 16, 1024, 64) | causal_mask | torch.bfloat16 | 262.378 | 254.350 | 955.280 | 1843.807 | 1.032 | 0.518 | | (16, 16, 1024, 64) | relative_bias | torch.bfloat16 | 261.338 | 268.253 | 956.038 | 1820.036 | 0.974 | 0.525 | | (16, 16, 1024, 64) | head_bias | torch.bfloat16 | 262.153 | 264.156 | 956.023 | 1810.076 | 0.992 | 0.528 | | (16, 16, 1024, 128) | noop | torch.bfloat16 | 476.475 | 461.413 | 1760.578 | 4306.521 | 1.033 | 0.409 | | (16, 16, 1024, 128) | causal_mask | torch.bfloat16 | 473.794 | 479.178 | 1761.277 | 4619.439 | 0.989 | 0.381 | | (16, 16, 1024, 128) | relative_bias | torch.bfloat16 | 473.839 | 463.282 | 1758.692 | 4290.562 | 1.023 | 0.410 | | (16, 16, 1024, 128) | head_bias | torch.bfloat16 | 472.979 | 472.896 | 1763.086 | 4367.931 | 1.000 | 0.404 | | (16, 16, 1024, 256) | noop | torch.bfloat16 | 1014.184 | 1026.764 | 3922.997 | 19104.147 | 0.988 | 0.205 | | (16, 16, 1024, 256) | causal_mask | torch.bfloat16 | 1013.217 | 1039.046 | 3928.382 | 21086.281 | 0.975 | 0.186 | | (16, 16, 1024, 256) | relative_bias | torch.bfloat16 | 1008.519 | 1015.278 | 3922.133 | 18980.652 | 0.993 | 0.207 | | (16, 16, 1024, 256) | head_bias | torch.bfloat16 | 1011.360 | 1047.542 | 3931.245 | 19069.172 | 0.965 | 0.206 | | (16, 16, 4096, 64) | noop | torch.bfloat16 | 3929.850 | 3325.667 | 11411.704 | 23344.280 | 1.182 | 0.489 | | (16, 16, 4096, 64) | causal_mask | torch.bfloat16 | 3885.262 | 3581.544 | 11390.515 | 23725.639 | 1.085 | 0.480 | | (16, 16, 4096, 64) | relative_bias | torch.bfloat16 | 3865.737 | 3537.308 | 11489.901 | 23406.330 | 1.093 | 0.491 | | (16, 16, 4096, 64) | head_bias | torch.bfloat16 | 3880.530 | 3665.249 | 11484.411 | 23299.496 | 1.059 | 0.493 | | (16, 16, 4096, 128) | noop | torch.bfloat16 | 7030.306 | 6745.715 | 20621.264 | 57464.096 | 1.042 | 0.359 | | (16, 16, 4096, 128) | causal_mask | torch.bfloat16 | 7095.414 | 7034.385 | 20410.656 | 61660.511 | 1.009 | 0.331 | | (16, 16, 4096, 128) | relative_bias | torch.bfloat16 | 7084.779 | 6686.497 | 20315.161 | 57243.969 | 1.060 | 0.355 | | (16, 16, 4096, 128) | head_bias | torch.bfloat16 | 7075.367 | 6863.305 | 20494.385 | 58481.953 | 1.031 | 0.350 | | (16, 16, 4096, 256) | noop | torch.bfloat16 | 15612.741 | 14297.482 | 55306.847 | 281161.865 | 1.092 | 0.197 | | (16, 16, 4096, 256) | causal_mask | torch.bfloat16 | 15326.592 | 14263.878 | 55227.806 | 313063.232 | 1.075 | 0.176 | | (16, 16, 4096, 256) | relative_bias | torch.bfloat16 | 15297.963 | 14007.379 | 54558.029 | 279529.175 | 1.092 | 0.195 | | (16, 16, 4096, 256) | head_bias | torch.bfloat16 | 15216.160 | 14276.027 | 55081.581 | 280996.826 | 1.066 | 0.196 | </details> Pull Request resolved: https://github.com/pytorch/pytorch/pull/125515 Approved by: https://github.com/Chillee |
||
|
|
927e631dc2 |
[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)
As part of #125683, this PR adds the initial bf16/fp16 gemm template support with micro-gemm implemented with fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet which will be added in the next PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126068 Approved by: https://github.com/jansel ghstack dependencies: #126019 |
||
|
|
c53e0ac7ba |
[Inductor] Generalize new introduced device-bias code. (#126261)
We find some Inductor test case failues when enabling Inductor UT for Intel GPU, the root cause is new introduced Inductor device-bias code from recent community PRs, which cause differnet beheaviors between Intel GPU and CUDA. This PR generalize these codes to align their beheaviors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126261 Approved by: https://github.com/EikanWang, https://github.com/peterbell10 |
||
|
|
f060b0c6e6 |
[inductor][cpp] GEMM template (infra and fp32) (#124021)
This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC https://github.com/pytorch/pytorch/issues/125683 for more background info. 1. Cpp template infrastructure Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`. The MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates. 2. Initial FP32 gemm template This involves a GEMM template implementation `CppPackedGemmTemplate` that supports GEMM with constant weight (`B`) requiring `N` to be a multiple of register blocking while allows the static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix would be prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`). Then it invokes `CppMicroGemm` which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls implemented with ATen vec abstraction. 3. Correctness and performance The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static shape and dynamic shapes. Since it is an initial implementation, we are still working on further performance improves with follow-up PRs including the optimizations in kernels as well as fusions. The perf gains are only observed from a selective number of models compared to the ATen kernels which are implemented with MKL. The perf gains are more obvious with dynamic shapes since MKL only supports packed gemm for static shapes. Below are details. Static shapes | Benchmark | torchbench | huggingface | timm_models | |------------|-------------|--------------|--------------| | Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x | | Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x | | Single-threaded (baseline) | 1.56x | 1.19x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x | Key models being sped up: drq: 1.14x soft_act: 1.12 cait_m36_384: 1.18x Dynamic shapes | Benchmark | torchbench | huggingface | timm_models | | --- | --- | --- | --- | | Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x | | Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x | | Single-threaded (baseline) | 1.55x | 1.20x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x | Key models being sped up: BERT_pytorch: 1.22x pyhpc_turbulent: 1.13x soft_actor_critic: 1.77x BlenderbotForCausalLM: 1.09x cait_m36_384: 1.17x Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021 Approved by: https://github.com/jansel |
||
|
|
b6d8b256e6 |
Revert "[inductor][cpp] GEMM template (infra and fp32) (#124021)"
This reverts commit
|
||
|
|
d1f254dce8 |
Add a cache mechanism to accelerate torch.compile-for-eager (#116368)
This PR is a follow-up of RFC https://github.com/pytorch/pytorch/issues/115545.
In this PR, we are trying to enable a cache mechanism to accelerate **eager-through-torch.compile**. When **eager-through-torch.compile** is enabled, we will store a persistent config to cache the kernel information for the aten operation.
The persistent config consists of two parts - meta_info and kernel_path.
- meta_info: The input tensors' shape, stride, device type, data type, and symbolic flag.
- kernel_path: The path of the kernel produced by Inductor.
When an aten operation is registered, the `kernel_holder` will load the persistent config and parse it to build the cache map; the meta_info is key, and the kernel library is the value.
Currently, this PR only supports static shape to guard the kernel.
Take a `mul` as an example.
```python
class MulKernel:
def __init__(self) -> None:
pass
def __call__(self, *args: Any, **kwargs: Any) -> Any:
with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.Python, False):
opt_fn = torch.compile(torch.ops.aten.mul, dynamic=False, options={
"aot_inductor.eager_mode": True,
"aot_inductor.eager_op_name": "mul_Tensor"
}
)
return opt_fn(*args, **kwargs)
torch_compile_op_lib_impl = torch.library.Library("aten", "IMPL")
_, overload_names = torch._C._jit_get_operation("aten::mul")
schema = torch._C._get_schema("aten::mul", overload_name)
reg_name = schema.name
if schema.overload_name:
reg_name = f"{reg_name}.{schema.overload_name}"
torch_compile_op_lib_impl.impl(
reg_name,
MulKernel(),
"CUDA",
compile_mode=True)
a = torch.randn(1024, 1024, device=device)
b = torch.randn(1024, 1024, device=device)
warm_up_iter = 1000
iter = 10000
fn = torch.mul
# Warm up
for _ in range(warm_up_iter):
fn(a, b)
# Collect performance
beg = time.time()
for _ in range(iter):
fn(a, b)
end = time.time()
print(f"E2E run: {end - beg}")
```
It will produce the config as follows.
```json
[
{
"meta_info": [
{
"is_symbolic": false,
"device_type": "cuda",
"dtype": "torch.float32",
"sizes": [1024, 1024],
"strides": [1024, 1]
},
{
"is_symbolic": false,
"device_type": "cuda",
"dtype": "torch.float32",
"sizes": [1024, 1024],
"strides": [1024, 1]
}
],
"kernel_path": "/tmp/torchinductor_eikan/e4/ce4jw46i5l2e7v3tvr2pyglpjmahnp7x3hxaqotrvxwoeh5t6qzc.so"
}
]
```
Performance-wise, we collected mul.Tensor through torch.compile w/ 10000 runs(e2e). The data is as follows. And we will collect data when we support dynamic shape.
- Eager: ~266.11ms
- W/O Cache: ~3455.54ms
- W/ Cache and Cache Miss: ~3555.3ms
- W/ Cache and Cache Hit: ~267.12ms
Hardware:
- CPU: Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz
- GPU: CUDA A10
Software:
- PyTorch Version:
|
||
|
|
037615b989 |
[inductor][cpp] GEMM template (infra and fp32) (#124021)
This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC https://github.com/pytorch/pytorch/issues/125683 for more background info. 1. Cpp template infrastructure Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`. The MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates. 2. Initial FP32 gemm template This involves a GEMM template implementation `CppPackedGemmTemplate` that supports GEMM with constant weight (`B`) requiring `N` to be a multiple of register blocking while allows the static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix would be prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`). Then it invokes `CppMicroGemm` which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls implemented with ATen vec abstraction. 3. Correctness and performance The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static shape and dynamic shapes. Since it is an initial implementation, we are still working on further performance improves with follow-up PRs including the optimizations in kernels as well as fusions. The perf gains are only observed from a selective number of models compared to the ATen kernels which are implemented with MKL. The perf gains are more obvious with dynamic shapes since MKL only supports packed gemm for static shapes. Below are details. Static shapes | Benchmark | torchbench | huggingface | timm_models | |------------|-------------|--------------|--------------| | Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x | | Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x | | Single-threaded (baseline) | 1.56x | 1.19x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x | Key models being sped up: drq: 1.14x soft_act: 1.12 cait_m36_384: 1.18x Dynamic shapes | Benchmark | torchbench | huggingface | timm_models | | --- | --- | --- | --- | | Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x | | Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x | | Single-threaded (baseline) | 1.55x | 1.20x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x | Key models being sped up: BERT_pytorch: 1.22x pyhpc_turbulent: 1.13x soft_actor_critic: 1.77x BlenderbotForCausalLM: 1.09x cait_m36_384: 1.17x Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021 Approved by: https://github.com/jansel |
||
|
|
320af5eaa6 |
Compute bounds for the variables created during codegen (#123100)
Before we would just bail out on these bounds for all variables that did not come from the FX graph. Now we propagate the bounds whenever we have a rule for that op. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123100 Approved by: https://github.com/jgong5, https://github.com/peterbell10 |
||
|
|
2a42c40791 |
Revert "Compute bounds for the variables created during codegen (#123100)"
This reverts commit |
||
|
|
7ffa5558ee |
Revert "[FX] Update type hints in torch.fx._compatibility.py (#125469)"
This reverts commit
|
||
|
|
bb668c6468 |
Compute bounds for the variables created during codegen (#123100)
Before we would just bail out on these bounds for all variables that did not come from the FX graph. Now we propagate the bounds whenever we have a rule for that op. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123100 Approved by: https://github.com/jgong5, https://github.com/peterbell10 |
||
|
|
58d8388ed3 |
Remove Inductor IRs for legacy functional collectives (#124992)
This PR completely removes the Inductor IR for legacy functional collectives: - Removed the `CollectiveKernel` hiearchy and `Wait`, as well as the corresponding lowerings. These IRs are target (i.e. Python) specific and don't model node dependencies propoerly (e.g. they rely on `never_reuse_buffers` for correct behavior). They've been superceded by `ir._CollectiveKernel`. - Removed `InPlaceHint` and the scheduler logic for handling it. `InPlaceHint` is a codegen-time buffer reuse mechanism controlled by the IR's codegen. It's a bit hacky and overlaps with the default buffer reuse mechanism. Removing it since it is only used by legacy functional collectives. - Removed `OutputBuffer` and `MultiOutputNoSizeAssert` which are designed for and only used by legacy functional collectives. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124992 Approved by: https://github.com/Chillee, https://github.com/wanchaol |
||
|
|
235b4d6ec2 |
[FX] Update type hints in torch.fx._compatibility.py (#125469)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125469 Approved by: https://github.com/Skylion007 ghstack dependencies: #125468 |
||
|
|
6f70d22277 |
Extend torch.utils._sympy.symbol for more Inductor symbols (#125419)
I'm still missing a few, cdzq at least Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/125419 Approved by: https://github.com/lezcano ghstack dependencies: #125395 |
||
|
|
46f326eff5 |
explicitly reset stderr/stdout in precompilation (#125289)
I was seeing a weird bug where after running max-autotune my stdout would be misdirected. other people have not been able to repro this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125289 Approved by: https://github.com/shunting314, https://github.com/mlazos |