Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
90d5a6f001 · 230 commits
90d5a6f001
[inductor] Add lowering and codegen for aten.sort (#128458)
Closes #125633

Benchmarks:

| Shape | dim | stable | compiled | eager | speedup |
|-------------|-----|--------|----------|---------|---------|
| (256, 4096) | 0 | False | 0.73 ms | 1.26 ms | 1.7 |
| (256, 4096) | 0 | True | 0.75 ms | 1.27 ms | 1.7 |
| (4096, 256) | 1 | False | 0.20 ms | 0.73 ms | 3.7 |
| (4096, 256) | 1 | True | 0.21 ms | 0.73 ms | 3.5 |
| (255, 4096) | 0 | False | 1.05 ms | 1.48 ms | 1.4 |
| (255, 4096) | 0 | True | 1.03 ms | 1.47 ms | 1.4 |
| (4096, 255) | 1 | False | 0.52 ms | 0.98 ms | 1.9 |
| (4096, 255) | 1 | True | 0.54 ms | 1.00 ms | 1.9 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128458
Approved by: https://github.com/lezcano, https://github.com/eellison
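As a rough illustration of how such compiled-vs-eager numbers can be gathered, here is a sketch using `torch.utils.benchmark`; the shape, `dim`, `stable` flag, and CUDA device below are picked from the rows above and are purely illustrative.

```python
import torch
from torch.utils.benchmark import Timer

def sort_dim0(x):
    # aten.sort along dim 0 with stable=False, mirroring one benchmark row.
    return torch.sort(x, dim=0, stable=False)

compiled_sort = torch.compile(sort_dim0)
x = torch.randn(256, 4096, device="cuda")
compiled_sort(x)  # warm up / trigger compilation

for label, fn in [("eager", sort_dim0), ("compiled", compiled_sort)]:
    t = Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
    print(label, t.blocked_autorange().median)
```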
533c4190f9
[inductor][cpp] support nested kernel with indirect indexing (#129223)
This PR makes sure the current kernel is used for generating CSE variables when nested kernel codegen is involved, e.g., when a nested CppKernel generates the epilogue of a CppTemplateKernel. Without the fix, an epilogue with indirect indexing would fail to run.
pytest -k test_linear_with_embedding_bias_False_cpu test_cpu_select_algorithm.py
Epilogue code Before:
```c++
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(m_end + ((-1L)*m_start)); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1+=static_cast<long>(16L))
{
auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
auto tmp11 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<long>(x1 + (N0*x0)), 16);
auto tmp1 = 64L;
auto tmp2 = c10::convert<int64_t>(tmp1);
auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
auto tmp4 = tmp0 ? tmp3 : tmp0;
auto tmp5 = decltype(tmp4)(tmp4 + tmp2);
auto tmp6 = tmp1 ? tmp5 : tmp4;
auto tmp7 = tmp6;
auto tmp8 = c10::convert<int64_t>(tmp7);
TORCH_CHECK((0 <= tmp8) & (tmp8 < 64L), "index out of bounds: 0 <= tmp8 < 64L");
auto tmp10 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(n_start + x1 + (384L*tmp6)), 16);
auto tmp12 = (tmp11);
auto tmp13 = tmp10 + tmp12;
tmp13.store(Y + static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0)));
}
#pragma omp simd simdlen(8)
for(long x1=static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1<static_cast<long>(N0); x1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
auto tmp11 = local_acc_buf[static_cast<long>(x1 + (N0*x0))];
auto tmp1 = 64L;
auto tmp2 = c10::convert<int64_t>(tmp1);
auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
auto tmp4 = tmp0 ? tmp3 : tmp0;
auto tmp5 = decltype(tmp4)(tmp4 + tmp2);
auto tmp6 = tmp1 ? tmp5 : tmp4;
auto tmp7 = tmp6;
auto tmp8 = c10::convert<int64_t>(tmp7);
TORCH_CHECK((0 <= tmp8) & (tmp8 < 64L), "index out of bounds: 0 <= tmp8 < 64L");
TORCH_CHECK((0 <= tmp8) & (tmp8 < 64L), "index out of bounds: 0 <= tmp8 < 64L");
auto tmp10 = in_ptr3[static_cast<long>(n_start + x1 + (384L*tmp6))];
auto tmp12 = c10::convert<float>(tmp11);
auto tmp13 = decltype(tmp10)(tmp10 + tmp12);
Y[static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0))] = tmp13;
}
}
}
```
Epilogue code After:
```c++
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(m_end + ((-1L)*m_start)); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1+=static_cast<long>(16L))
{
auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
auto tmp13 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<long>(x1 + (N0*x0)), 16);
auto tmp1 = 64L;
auto tmp2 = c10::convert<int64_t>(tmp1);
auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
auto tmp4 = tmp0 < 0;
auto tmp5 = tmp4 ? tmp3 : tmp0;
auto tmp6 = decltype(tmp5)(tmp5 + tmp2);
auto tmp7 = tmp5 < 0;
auto tmp8 = tmp7 ? tmp6 : tmp5;
auto tmp9 = tmp8;
auto tmp10 = c10::convert<int64_t>(tmp9);
TORCH_CHECK((0 <= tmp10) & (tmp10 < 64L), "index out of bounds: 0 <= tmp10 < 64L");
auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(n_start + x1 + (384L*tmp8)), 16);
auto tmp14 = (tmp13);
auto tmp15 = tmp12 + tmp14;
tmp15.store(Y + static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0)));
}
#pragma omp simd simdlen(8)
for(long x1=static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1<static_cast<long>(N0); x1+=static_cast<long>(1L))
{
auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
auto tmp13 = local_acc_buf[static_cast<long>(x1 + (N0*x0))];
auto tmp1 = 64L;
auto tmp2 = c10::convert<int64_t>(tmp1);
auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
auto tmp4 = tmp0 < 0;
auto tmp5 = tmp4 ? tmp3 : tmp0;
auto tmp6 = decltype(tmp5)(tmp5 + tmp2);
auto tmp7 = tmp5 < 0;
auto tmp8 = tmp7 ? tmp6 : tmp5;
auto tmp9 = tmp8;
auto tmp10 = c10::convert<int64_t>(tmp9);
TORCH_CHECK((0 <= tmp10) & (tmp10 < 64L), "index out of bounds: 0 <= tmp10 < 64L");
TORCH_CHECK((0 <= tmp10) & (tmp10 < 64L), "index out of bounds: 0 <= tmp10 < 64L");
auto tmp12 = in_ptr3[static_cast<long>(n_start + x1 + (384L*tmp8))];
auto tmp14 = c10::convert<float>(tmp13);
auto tmp15 = decltype(tmp12)(tmp12 + tmp14);
Y[static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0))] = tmp15;
}
}
}
```
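The substantive difference in the corrected epilogue is the `tmp0 < 0` / `tmp5 < 0` comparisons feeding the ternaries: a negative embedding index is wrapped by adding the table size only when it is actually negative, instead of using the index value itself as the condition. A minimal Python sketch of that wrapping rule (illustrative only, not Inductor-generated code):

```python
def wrap_index(idx: int, size: int) -> int:
    # Wrap a possibly-negative index into [0, size), as the generated
    # `tmp4 = tmp0 < 0; tmp5 = tmp4 ? tmp3 : tmp0` sequence does.
    wrapped = idx + size if idx < 0 else idx
    assert 0 <= wrapped < size, "index out of bounds"
    return wrapped

assert wrap_index(-1, 64) == 63
assert wrap_index(5, 64) == 5
```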
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129223
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
1a54bb0f96
Revert "[halide-backend] Initial implementation of HalideKernel and HalideScheduling (#126417)"
This reverts commit
063facf352
Revert "[halide-backend] Generate standalone runtime (#129025)"
This reverts commit
10c64c3b49
[halide-backend] Generate standalone runtime (#129025)
This puts the halide runtime in a global shared object, rather than copying it to each kernel. Having many copies of the runtime causes many issues with cuda. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129025 Approved by: https://github.com/shunting314, https://github.com/eellison ghstack dependencies: #126417 |
4f9399bd0d
[halide-backend] Initial implementation of HalideKernel and HalideScheduling (#126417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126417 Approved by: https://github.com/shunting314, https://github.com/eellison |
feb3f3ad77
[inductor] Refactors for Halide backend (#129024)
Pulling these inductor-related refactors out of the larger Halide backend PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129024 Approved by: https://github.com/shunting314, https://github.com/eellison |
c187593418
Prevent expansion of cat indexing to avoid int64 intermediate (#127815)
Fix for https://github.com/pytorch/pytorch/issues/127652 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127815 Approved by: https://github.com/shunting314, https://github.com/peterbell10 |
e397ad6883
Improve codegen for ops.masked in triton (#128054)
Fixes https://github.com/pytorch/pytorch/issues/127930 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128054 Approved by: https://github.com/peterbell10, https://github.com/lezcano |
c897651392
[inductor] Add BackendFeature gating (#128266)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128266 Approved by: https://github.com/shunting314 |
f2dcbe89d6
Revert "Prevent expansion of cat indexing to avoid int64 intermediate (#127815)"
This reverts commit
793df7b7cb
Prevent expansion of cat indexing to avoid int64 intermediate (#127815)
Fix for https://github.com/pytorch/pytorch/issues/127652 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127815 Approved by: https://github.com/shunting314, https://github.com/peterbell10 |
3964a3ec73
Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is: * Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.) * Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers. The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions: * FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing). * ModularIndexing, LShift, RShift now assert they are given integer inputs. * Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver * TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division. * Trunc is split to TruncToFloat and TruncToInt. * Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result. * RoundDecimal updated to consistently only ever return a float * Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing) In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information. We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**: * `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy * `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv` These changes have consequences. 
First, we need to make some administrative changes: * Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2) * Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py** * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here * Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet * Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions. In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments: * Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now * `_assert_bound_is_rational` is no more, we no longer generate rational bounds * Don't intersect non-int value ranges with the `int_range` * Support more sympy Functions for guard SYMPY_INTERP * Assert the type of value range is consistent with the variable type The new asserts uncovered necessary bug fixes: * **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions * **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions * **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr! * **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1 Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py** **Reland notes.** This requires this internal fbcode diff https://www.internalfb.com/phabricator/paste/view/P1403322587 but I cannot prepare the diff codev due to https://fb.workplace.com/groups/osssupport/posts/26343544518600814/ It also requires this Executorch PR https://github.com/pytorch/executorch/pull/3911 but the ET PR can be landed prior to this landing. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905 Approved by: https://github.com/xadupre, https://github.com/lezcano |
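To make the IntTrueDiv motivation concrete, here is a small plain-Python illustration (not the Sympy machinery) of why coercing large integer operands to float before dividing can differ from correctly rounded integer true division once the inputs exceed 2**53:

```python
# Python's int / int already produces the correctly rounded float quotient,
# so it serves as the reference for the >= 2**53 regime discussed above.
a = 2**53 + 1
b = 3

exact_style = a / b                 # correctly rounded from the exact integers
float_first = float(a) / float(b)   # coerce operands to float first, then divide

print(exact_style)   # 3002399751580331.0
print(float_first)   # 3002399751580330.5
assert exact_style != float_first
```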
ea614fb2b1
Flip default value for mypy disallow_untyped_defs [2/11] (#127839)
See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127839 Approved by: https://github.com/oulgen |
ac51f782fe
Revert "Complete revamp of float/promotion sympy handling (#126905)"
This reverts commit
2f7cfecd86
Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is: * Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.) * Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers. The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions: * FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing). * ModularIndexing, LShift, RShift now assert they are given integer inputs. * Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver * TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division. * Trunc is split to TruncToFloat and TruncToInt. * Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result. * RoundDecimal updated to consistently only ever return a float * Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing) In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information. We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**: * `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy * `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv` These changes have consequences. 
First, we need to make some administrative changes: * Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2) * Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py** * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here * Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet * Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions. In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments: * Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now * `_assert_bound_is_rational` is no more, we no longer generate rational bounds * Don't intersect non-int value ranges with the `int_range` * Support more sympy Functions for guard SYMPY_INTERP * Assert the type of value range is consistent with the variable type The new asserts uncovered necessary bug fixes: * **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions * **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions * **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr! * **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1 Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py** Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905 Approved by: https://github.com/xadupre, https://github.com/lezcano |
d5cb5d623a
Revert "Complete revamp of float/promotion sympy handling (#126905)"
This reverts commit
fb696ef3aa
Complete revamp of float/promotion sympy handling (#126905)
At a high level, the idea behind this PR is: * Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.) * Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers. The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions: * FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing). * ModularIndexing, LShift, RShift now assert they are given integer inputs. * Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver * TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division. * Trunc is split to TruncToFloat and TruncToInt. * Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result. * RoundDecimal updated to consistently only ever return a float * Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing) In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information. We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**: * `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy * `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv` These changes have consequences. 
First, we need to make some administrative changes: * Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2) * Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py** * In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here * Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet * Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions. In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments: * Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now * `_assert_bound_is_rational` is no more, we no longer generate rational bounds * Don't intersect non-int value ranges with the `int_range` * Support more sympy Functions for guard SYMPY_INTERP * Assert the type of value range is consistent with the variable type The new asserts uncovered necessary bug fixes: * **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions * **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions * **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr! * **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1 Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py** Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905 Approved by: https://github.com/xadupre, https://github.com/lezcano |
a4064da8ca
Always simplify sympy expressions before printing. (#127543)
This is important because if a replacement has happened during inductor lowering, we may have stale symbols in sympy expressions that we need to replace away. Do this at the very end. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/127543 Approved by: https://github.com/lezcano |
0fa2c5b049
Fix mask propagation in the presence of where (#125574)
Before, when calling ops.where, masks were not properly propagated. We now restrict the optimisation to `ops.masked`, which I think is what the original code intended to do. I'm not 100% sure that even in the masked case this code is not introducing some bugs, but this is a strict improvement over the previous state. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125574 Approved by: https://github.com/peterbell10 ghstack dependencies: #114471, #126783
92bc444ee3
[inductor][cpp] epilogue support for gemm template (#126019)
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126019 Approved by: https://github.com/jansel ghstack dependencies: #124021 |
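A user-level sketch of the kind of pattern this targets: a CPU linear whose pointwise epilogue can be fused into the GEMM template under max-autotune. Whether the C++ GEMM template is actually selected depends on the Inductor version and autotune configuration, so treat this purely as an illustrative setup.

```python
import torch
import torch._inductor.config as inductor_config

# Assumption: the C++ GEMM template is considered during max-autotune on CPU;
# additional backend flags may be needed depending on the release.
inductor_config.max_autotune = True

class LinearRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 128)

    def forward(self, x):
        # The bias add and relu form the pointwise epilogue that can be
        # retraced onto sub-slices of the GEMM output, as described above.
        return torch.relu(self.linear(x))

model = LinearRelu().eval()
x = torch.randn(32, 64)
with torch.no_grad():
    out = torch.compile(model)(x)
```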
8a21532e53
Fix constant propagation pass (#114471)
This pass was broken in a number of ways, as we were not generating asserts whenever we took it, even though we need to. While doing so, we found that the analysis we were using for choosing whether to generate asserts or not for dynamic shapes was completely broken. Eliminating indirect indexing in this way allows for a number of optimisations. In particular, we can now fuse against these kernels (indirect indexing disallows fusions). The new strategy is as follows: - We always propagate sympy expressions if we can. - If an expression was an indirect_indexing, we call `check_bounds` - We also call `check_bounds` within `CSEProxy.indirect_indexing` - The checks are issued in the buffer where they would go if the were used in a load - This makes them always be codegen'd before the load and stores - In the case of stores, they will be generated potentially much earlier than the stores themselves, which is fine. We add quite a few asserts to preexisting tests to strengthen them. In particular, we make sure that issuing an assert plays well with all kinds of C++ vectorisation. For now, we rely on the logic within `_maybe_evaluate_static` to prove these bounds. This logic is rather limited though. In the future, we might want to rely on Z3 here to be able to prove bounds in a more general way. Supersedes https://github.com/pytorch/pytorch/pull/113068 Fixes https://github.com/pytorch/pytorch/issues/121251 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114471 Approved by: https://github.com/peterbell10 |
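For context, a minimal example of the indirect indexing pattern that needs a `check_bounds`-style runtime assert when compiled (a sketch; the generated guard code is backend-specific):

```python
import torch

def gather_rows(x, idx):
    # `idx` is a data-dependent tensor, so the load `x[idx]` is indirect
    # indexing: Inductor must assert 0 <= idx < x.size(0) at runtime.
    return x[idx] * 2.0

compiled = torch.compile(gather_rows)
x = torch.randn(64, 8)
idx = torch.randint(0, 64, (16,))
out = compiled(x, idx)
assert out.shape == (16, 8)
```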
343a41fba8
Revert "[inductor][cpp] epilogue support for gemm template (#126019)"
This reverts commit
92433217cb
[inductor] Misc refactors (#126945)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126945 Approved by: https://github.com/shunting314 ghstack dependencies: #126944 |
56c412d906
[inductor][cpp] epilogue support for gemm template (#126019)
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126019 Approved by: https://github.com/jansel ghstack dependencies: #124021 |
45784cd229
Revert "[inductor][cpp] epilogue support for gemm template (#126019)"
This reverts commit
08f57b4bff
[inductor][cpp] epilogue support for gemm template (#126019)
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126019 Approved by: https://github.com/jansel ghstack dependencies: #124021 |
657d39e44c
Revert "[inductor][cpp] epilogue support for gemm template (#126019)"
This reverts commit
57108d9a49
[inductor][cpp] epilogue support for gemm template (#126019)
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126019 Approved by: https://github.com/jansel ghstack dependencies: #124021 |
14c5c753de
[inductor] use smaller RBLOCK for expensive reduction kernels (#126477)
Triton sometimes uses fewer registers for more expensive kernels, which results in worse perf (https://github.com/pytorch/pytorch/issues/126463). This may make Inductor end up with a sub-optimal config. Use a smaller max RBLOCK if the reduction potentially needs many registers. Will run perf tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126477 Approved by: https://github.com/jansel
b40fb2de59
[AOTI] Fix a codegen issue when .item() is used for kernel arg (#126575)
Summary: fixes https://github.com/pytorch/pytorch/issues/126574 . Pass kernel argument type information into generate_args_decl, so it can generate the argument declaration instead of relying on string matching. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126575 Approved by: https://github.com/chenyang78 ghstack dependencies: #126369 |
55033ab43a
Update ops handler documentation some more (#126480)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/126480 Approved by: https://github.com/peterbell10 ghstack dependencies: #126292, #126299 |
4a5ef0b793
Revert "[inductor][cpp] epilogue support for gemm template (#126019)"
This reverts commit
7844c202b2
[inductor][cpp] epilogue support for gemm template (#126019)
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126019 Approved by: https://github.com/jansel |
2ba102f689
Implement native support for float inputs in Dynamo and ShapeEnv (#125325)
The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside, we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is enabled by setting the `specialize_float` config variable to False.
The generated graph looks like this for the test `test_unspec_float_output`:
```
def forward(self, L_x_: "f32[3]", L_y_: "f32[]"):
l_x_ = L_x_
l_y_ = L_y_
# File: /data/users/ezyang/a/pytorch/test/dynamo/test_unspec.py:511 in f, code: return x + 1, y * 2
add: "f32[3]" = l_x_ + 1; l_x_ = None
item: "Sym(zf0)" = l_y_.item(); l_y_ = None
mul: "Sym(2*zf0)" = item * 2; item = None
scalar_tensor: "f32[]" = torch.scalar_tensor(mul); mul = None
return (add, scalar_tensor)
```
The ingredients:
* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up into the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, which we will do the rest of the tracing with. Importantly, this SymNodeVariable is backed with the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable). This has to be done manually, because if you literally call item() on the tensor, you will end up with an unbacked float. There is a bit of copy paste from wrap_symint and wrap_unspecialized_primitive which we can try to factor out, but this really is its own thing and you should review every line of code in the function.
* **torch/fx/experimental/symbolic_shapes.py** We now can generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols, and produce guards for them. Fairly straightforward generalization.
* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs to the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable on the return stack for a float, we on the fly convert it (via `as_tensor`) to a TensorVariable, which is the true output. We then special case the output bytecode to call item() on it again. The tensor conversion is memoized on SymNodeVariable since we typically run the code generation process twice.
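For reference, a sketch of the kind of function that produces the graph shown above; the config attribute path is an assumption and may differ across versions.

```python
import torch
import torch._dynamo.config as dynamo_config

# Assumption: the flag described above lives at torch._dynamo.config.specialize_float.
dynamo_config.specialize_float = False

@torch.compile(backend="eager")
def f(x, y):
    # Matches the traced source in the graph above: a tensor op plus a float op.
    return x + 1, y * 2

out_tensor, out_float = f(torch.randn(3), 5.0)
```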
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125325
Approved by: https://github.com/lezcano, https://github.com/jansel
320af5eaa6
Compute bounds for the variables created during codegen (#123100)
Before we would just bail out on these bounds for all variables that did not come from the FX graph. Now we propagate the bounds whenever we have a rule for that op. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123100 Approved by: https://github.com/jgong5, https://github.com/peterbell10 |
2a42c40791
Revert "Compute bounds for the variables created during codegen (#123100)"
This reverts commit
bb668c6468
Compute bounds for the variables created during codegen (#123100)
Before we would just bail out on these bounds for all variables that did not come from the FX graph. Now we propagate the bounds whenever we have a rule for that op. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123100 Approved by: https://github.com/jgong5, https://github.com/peterbell10 |
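Conceptually, this is interval arithmetic applied per op as CSE variables are created. A toy sketch of the idea (not the ValueRanges API) for a multiply and a truncation:

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class Bounds:
    lower: float
    upper: float

# Toy propagation rules keyed per op; Inductor's real rules live in ValueRanges.
def prop_mul(a: Bounds, b: Bounds) -> Bounds:
    products = [a.lower * b.lower, a.lower * b.upper,
                a.upper * b.lower, a.upper * b.upper]
    return Bounds(min(products), max(products))

def prop_trunc(a: Bounds) -> Bounds:
    # trunc is monotone, so truncating the endpoints gives a sound bound.
    return Bounds(math.trunc(a.lower), math.trunc(a.upper))

x = Bounds(-2.0, 3.0)                 # bound known for an input variable
y = prop_mul(x, Bounds(0.5, 2.0))     # Bounds(-4.0, 6.0)
z = prop_trunc(y)                     # Bounds(-4, 6)
print(y, z)
```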
68a1f787c8
[inductor][cpp] move some common cpp utils to cpp_utils.py (#125152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125152 Approved by: https://github.com/desertfire, https://github.com/jansel |
6f70d22277
Extend torch.utils._sympy.symbol for more Inductor symbols (#125419)
I'm still missing a few, cdzq at least Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/125419 Approved by: https://github.com/lezcano ghstack dependencies: #125395 |
5503c29357
Introduce torch.utils._sympy.symbol (#125395)
This provides utilities for creating and querying properties on sympy.Symbol. I want to use this refactor to get a better handle on how the 's' prefix is being used in Inductor. To start, I only do symbolic_shapes code because that's what I'm familiar with. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/125395 Approved by: https://github.com/Skylion007 |
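As a rough illustration of the prefix convention being formalized here, a plain-sympy sketch follows; the real helpers in torch.utils._sympy.symbol have their own names and signatures, so this is not that API.

```python
import sympy

# Illustrative only: encode the symbol "kind" in a name prefix, e.g. 's' for
# shape sizes and 'u' for unbacked ints, and query it back from the prefix.
def make_prefixed_symbol(prefix: str, idx: int, **assumptions) -> sympy.Symbol:
    return sympy.Symbol(f"{prefix}{idx}", **assumptions)

def has_prefix(sym: sympy.Symbol, prefix: str) -> bool:
    return sym.name.startswith(prefix)

s0 = make_prefixed_symbol("s", 0, integer=True, positive=True)
u0 = make_prefixed_symbol("u", 0, integer=True)
assert has_prefix(s0, "s") and not has_prefix(u0, "s")
```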
dae574c713
Don't make replacements for i variables (#125398)
This was introduced in https://github.com/pytorch/pytorch/pull/110262 but actually it looks like they were trying to hit unbacked SymInt. Now that unbacked SymInt is renamed to u, this code is no longer necessary Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/125398 Approved by: https://github.com/lezcano, https://github.com/Skylion007 |
c5b1a4c269
[inductor] share more cse cache during swap buffer (#124921)
`swap_buffer` prevents the `cse_cache` from being shared between the inside and the outside of the lambda function scope.
For example,
```
auto tmp8 = -std::numeric_limits<float>::infinity();
auto tmp9 = [&]
{
auto tmp12 = -std::numeric_limits<float>::infinity();
return tmp12;
}
```
`tmp12` should not be created, since it is the same as `tmp8`.
We make the `cse_cache` a read-only cache inside the scope (it is unsafe to expose a writable cache inside the scope, because the outside scope cannot use entries created there).
**Test Plan**
```
python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu
```
the `static_cast<int>(256)` will only occur once after this PR, since the scope inside the lambda can now share the CSE cache from the outside scope.
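A small sketch of the "read-only inside the scope" idea using a layered dict (conceptual; not the Inductor CSE implementation): lookups inside the lambda scope can hit entries created outside, but new entries created inside stay local and are discarded when the scope ends.

```python
from collections import ChainMap

outer_cache = {"-inf": "tmp8"}          # CSE entries created in the outer scope

# Inside the lambda scope: reads fall through to outer_cache, writes stay local.
inner_cache = ChainMap({}, outer_cache)

expr = "-inf"
if expr in inner_cache:
    var = inner_cache[expr]             # reuses tmp8 instead of minting tmp12
else:
    var = "tmp12"
    inner_cache[expr] = var             # would only land in the local layer

assert var == "tmp8" and "tmp12" not in outer_cache.values()
```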
Before this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr1)
{
#pragma omp parallel num_threads(128)
{
int tid = omp_get_thread_num();
{
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
{
for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
{
auto tmp0 = c10::convert<int>(x1);
auto tmp1 = static_cast<int>(256);
auto tmp2 = tmp0 < tmp1;
auto tmp3 = [&]
{
auto tmp4 = c10::convert<int>(x3);
auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
auto tmp6 = static_cast<int>(257);
auto tmp7 = at::vec::Vectorized<int>(tmp6);
auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
auto tmp11 = tmp8 & tmp10;
auto tmp9 = [&]
{
auto tmp12 = -std::numeric_limits<float>::infinity();
return tmp12;
}
;
auto tmp13 =
[&]
{
if (tmp11.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
}
}
()
;
auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
auto tmp15 = static_cast<int>(3);
auto tmp16 = tmp14 < tmp15;
auto tmp18 = tmp16 & tmp2;
auto tmp17 = [&]
{
auto tmp19 = c10::convert<int>(x3);
auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1);
auto tmp21 = static_cast<int>(256);
auto tmp22 = at::vec::Vectorized<int>(tmp21);
auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22);
auto tmp25 = at::vec::VecMask<float,1>::from(tmp18);
auto tmp26 = tmp23 & tmp25;
auto tmp24 = [&]
{
auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
return tmp27;
}
;
auto tmp28 =
[&]
{
if (tmp26.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>());
}
}
()
;
auto tmp29 = static_cast<float>(0.0);
auto tmp30 = at::vec::Vectorized<float>(tmp29);
auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>());
return tmp31;
}
;
auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
auto tmp33 = static_cast<float>(0.0);
auto tmp34 = at::vec::VecMask<float,1>::from(tmp16);
auto tmp35 = at::vec::Vectorized<float>(tmp33);
auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>());
auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>());
return tmp37;
}
;
auto tmp38 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
auto tmp40 = static_cast<int>(3);
auto tmp41 = tmp39 < tmp40;
auto tmp42 = [&]
{
auto tmp43 = c10::convert<int>(x3);
auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1);
auto tmp45 = static_cast<int>(256);
auto tmp46 = at::vec::Vectorized<int>(tmp45);
auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46);
auto tmp49 = at::vec::VecMask<float,1>::from(tmp41);
auto tmp50 = tmp47 & tmp49;
auto tmp48 = [&]
{
auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
return tmp51;
}
;
auto tmp52 =
[&]
{
if (tmp50.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>());
}
}
()
;
auto tmp53 = static_cast<float>(0.0);
auto tmp54 = at::vec::Vectorized<float>(tmp53);
auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>());
return tmp55;
}
;
auto tmp56 = tmp41 ? tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0));
auto tmp57 = static_cast<float>(0.0);
auto tmp58 = at::vec::VecMask<float,1>::from(tmp41);
auto tmp59 = at::vec::Vectorized<float>(tmp57);
auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>());
auto tmp61 = at::vec::VecMask<float,1>::from(tmp2);
auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>());
tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
}
#pragma omp simd simdlen(8)
for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
{
auto tmp0 = c10::convert<int64_t>(x1);
auto tmp1 = static_cast<int64_t>(256);
auto tmp2 = tmp0 < tmp1;
auto tmp3 = [&]
{
auto tmp4 = c10::convert<int64_t>(x3);
auto tmp5 = static_cast<int64_t>(257);
auto tmp6 = tmp4 < tmp5;
auto tmp7 = [&]
{
auto tmp8 = -std::numeric_limits<float>::infinity();
return tmp8;
}
;
auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
auto tmp11 = static_cast<int64_t>(3);
auto tmp12 = tmp10 < tmp11;
auto tmp13 = [&]
{
auto tmp14 = c10::convert<int64_t>(x3);
auto tmp15 = static_cast<int64_t>(256);
auto tmp16 = tmp14 >= tmp15;
auto tmp17 = [&]
{
auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
return tmp18;
}
;
auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
auto tmp20 = static_cast<float>(0.0);
auto tmp21 = tmp16 ? tmp19 : tmp20;
return tmp21;
}
;
auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
auto tmp23 = static_cast<float>(0.0);
auto tmp24 = tmp12 ? tmp22 : tmp23;
auto tmp25 = tmp6 ? tmp9 : tmp24;
return tmp25;
}
;
auto tmp26 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
auto tmp28 = static_cast<int64_t>(3);
auto tmp29 = tmp27 < tmp28;
auto tmp30 = [&]
{
auto tmp31 = c10::convert<int64_t>(x3);
auto tmp32 = static_cast<int64_t>(256);
auto tmp33 = tmp31 >= tmp32;
auto tmp34 = [&]
{
auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
return tmp35;
}
;
auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0);
auto tmp37 = static_cast<float>(0.0);
auto tmp38 = tmp33 ? tmp36 : tmp37;
return tmp38;
}
;
auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0);
auto tmp40 = static_cast<float>(0.0);
auto tmp41 = tmp29 ? tmp39 : tmp40;
auto tmp42 = tmp2 ? tmp26 : tmp41;
out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42;
}
}
}
}
}
}
}
''')
```
After this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr1)
{
#pragma omp parallel num_threads(128)
{
int tid = omp_get_thread_num();
{
#pragma omp for collapse(2)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
{
for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
{
auto tmp0 = c10::convert<int>(x1);
auto tmp1 = static_cast<int>(256);
auto tmp2 = tmp0 < tmp1;
auto tmp3 = [&]
{
auto tmp4 = c10::convert<int>(x3);
auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
auto tmp6 = static_cast<int>(257);
auto tmp7 = at::vec::Vectorized<int>(tmp6);
auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
auto tmp11 = tmp8 & tmp10;
auto tmp9 = [&]
{
auto tmp12 = -std::numeric_limits<float>::infinity();
return tmp12;
}
;
auto tmp13 =
[&]
{
if (tmp11.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
}
}
()
;
auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
auto tmp15 = static_cast<int>(3);
auto tmp16 = tmp14 < tmp15;
auto tmp18 = tmp16 & tmp2;
auto tmp17 = [&]
{
auto tmp19 = at::vec::Vectorized<int>(tmp1);
auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19);
auto tmp22 = at::vec::VecMask<float,1>::from(tmp18);
auto tmp23 = tmp20 & tmp22;
auto tmp21 = [&]
{
auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
return tmp24;
}
;
auto tmp25 =
[&]
{
if (tmp23.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>());
}
}
()
;
auto tmp26 = static_cast<float>(0.0);
auto tmp27 = at::vec::Vectorized<float>(tmp26);
auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>());
return tmp28;
}
;
auto tmp29 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
auto tmp30 = static_cast<float>(0.0);
auto tmp31 = at::vec::VecMask<float,1>::from(tmp16);
auto tmp32 = at::vec::Vectorized<float>(tmp30);
auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>());
auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>());
return tmp34;
}
;
auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
auto tmp37 = static_cast<int>(3);
auto tmp38 = tmp36 < tmp37;
auto tmp39 = [&]
{
auto tmp40 = c10::convert<int>(x3);
auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1);
auto tmp42 = at::vec::Vectorized<int>(tmp1);
auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42);
auto tmp45 = at::vec::VecMask<float,1>::from(tmp38);
auto tmp46 = tmp43 & tmp45;
auto tmp44 = [&]
{
auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
return tmp47;
}
;
auto tmp48 =
[&]
{
if (tmp46.all_zero())
{
return at::vec::Vectorized<float>(static_cast<float>(0.0));
}
else
{
return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>());
}
}
()
;
auto tmp49 = static_cast<float>(0.0);
auto tmp50 = at::vec::Vectorized<float>(tmp49);
auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>());
return tmp51;
}
;
auto tmp52 = tmp38 ? tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0));
auto tmp53 = static_cast<float>(0.0);
auto tmp54 = at::vec::VecMask<float,1>::from(tmp38);
auto tmp55 = at::vec::Vectorized<float>(tmp53);
auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>());
auto tmp57 = at::vec::VecMask<float,1>::from(tmp2);
auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>());
tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
}
#pragma omp simd simdlen(8)
for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
{
auto tmp0 = c10::convert<int64_t>(x1);
auto tmp1 = static_cast<int64_t>(256);
auto tmp2 = tmp0 < tmp1;
auto tmp3 = [&]
{
auto tmp4 = c10::convert<int64_t>(x3);
auto tmp5 = static_cast<int64_t>(257);
auto tmp6 = tmp4 < tmp5;
auto tmp7 = [&]
{
auto tmp8 = -std::numeric_limits<float>::infinity();
return tmp8;
}
;
auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
auto tmp11 = static_cast<int64_t>(3);
auto tmp12 = tmp10 < tmp11;
auto tmp13 = [&]
{
auto tmp14 = tmp4 >= tmp1;
auto tmp15 = [&]
{
auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
return tmp16;
}
;
auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0);
auto tmp18 = static_cast<float>(0.0);
auto tmp19 = tmp14 ? tmp17 : tmp18;
return tmp19;
}
;
auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
auto tmp21 = static_cast<float>(0.0);
auto tmp22 = tmp12 ? tmp20 : tmp21;
auto tmp23 = tmp6 ? tmp9 : tmp22;
return tmp23;
}
;
auto tmp24 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
auto tmp26 = static_cast<int64_t>(3);
auto tmp27 = tmp25 < tmp26;
auto tmp28 = [&]
{
auto tmp29 = c10::convert<int64_t>(x3);
auto tmp30 = tmp29 >= tmp1;
auto tmp31 = [&]
{
auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
return tmp32;
}
;
auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0);
auto tmp34 = static_cast<float>(0.0);
auto tmp35 = tmp30 ? tmp33 : tmp34;
return tmp35;
}
;
auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0);
auto tmp37 = static_cast<float>(0.0);
auto tmp38 = tmp27 ? tmp36 : tmp37;
auto tmp39 = tmp2 ? tmp24 : tmp38;
out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39;
}
}
}
}
}
}
}
''')
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124921
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #124597
f0f7452e31
Do not propogate (#124769)
Fix the propogate typos. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124769 Approved by: https://github.com/Skylion007 |
9a5b4d2403
Do not forward parent's value range to CSE variable for variables created within codegen. (#123099)
Consider we are generating code for `ops.gt`, and within it we call `ops.to_dtype`. Before, we would forward the bounds from `gt` to the to the result of `to_dtype`, which is wrong. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123099 Approved by: https://github.com/jgong5, https://github.com/peterbell10 |
93e249969b
[BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261 Approved by: https://github.com/albanD |
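For illustration, the kind of pattern the RSE rule flags and its fix:

```python
# Before: useless parentheses when the exception is raised with no arguments.
def check_before(x):
    if x < 0:
        raise ValueError()

# After: the empty call parentheses are removed.
def check_after(x):
    if x < 0:
        raise ValueError
```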
bd225189f1
[inductor] Change OverridesData to take callables instead of strings (#123397)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123397 Approved by: https://github.com/lezcano |
efa36ef092
Natively support int truncation, don't guard on positive/negative (#122827)
This doesn't entirely fix the original problem that prompted this, but it seems to just be getting stuck in export constraint formatting now which seems like progress to me. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/122827 Approved by: https://github.com/avikchaudhuri |
9189d04cb1
[inductor] Add explicit ops.fma and use it in softmax_backward (#122518)
This allows us to generate an fma even when fp-fusion is disabled in the compiler. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122518 Approved by: https://github.com/lezcano, https://github.com/Chillee |
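For context, a sketch of where the multiply-add shows up: the standard softmax backward formula, checked against autograd. This is the math only, not the Inductor decomposition that now emits ops.fma.

```python
import torch

def softmax_backward(grad_out, out, dim=-1):
    # dL/dx = out * grad_out - out * sum(grad_out * out, dim)
    # The `a * b - c`-shaped terms are where a fused multiply-add helps.
    return out * grad_out - out * (grad_out * out).sum(dim, keepdim=True)

x = torch.randn(4, 8, requires_grad=True)
out = torch.softmax(x, dim=-1)
grad_out = torch.randn_like(out)

(ref,) = torch.autograd.grad(out, x, grad_out)
torch.testing.assert_close(softmax_backward(grad_out, out), ref)
```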
f4e2a226aa
ScoreMod API (#121845)
# Summary This PR adds a new higher-order_op: `templated_attention`. This op is designed to extend the functionality of torch.nn.fucntional.scaled_dot_product_attention. PyTorch has efficient pre-written fused-attention kernels. However, users want to modify how scores are computed (a substep inside attention) -- this traditionally requires the user to write their own attention kernel. One such modification to attention scores that is not currently supported by the top level SDPA op is:[ Attention with Linear Biases (ALiBi](https://arxiv.org/abs/2108.12409)). This higher-order op will instead accept a callable( 'score_mod') function that is through torch.compile will be used to create an efficient attention kernel instantiation. ### Details This HOP utilizes the existing fx and HOP infra to capture and convert the User `score-mod` function and convert to an FX graph module. Inductor then consumes this HOP that has a `ir.Subgraph` input. It will inline this lowered subgraph into a triton kernel which performs fused attention with the modification to the scores matrix inlined. ### API The API for a score_mod function should be as follows: ```Python def score_mod(score: torch.Tensor, batch: torch.Tensor, head: torch.Tensor, token_1: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor ``` This function receives five parameters: - `score`: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors. - `batch`, `head`, `seq_len_q`, `seq_len_kv`: Scalar tensors indicating the batch index, head index, query index, and key/value index, respectively, with torch.int data type and located on the same device as the score tensor. Consider inputs query, key, value of shapes (2, 4, 16, 8), leading to an intermediate attention score matrix of shape (2, 4, 16, 16) The score_mod function will be vectorized over each element of this matrix. For instance, modifying the score at the position corresponding to the 0th batch, 2nd head, between the 8th query and the 9th key element, would be invoked as: ```Python score_mod(score[0,2,8,9], torch.tensor(0), torch.tensor(2), torch.tensor(8), torch.tensor(9)) ``` ### Examples ```Python import torch from torch.nn.attention.templated_attention import templated_attention torch.manual_seed(0) # Lets create some input tensors # The input tensor has shape (batch_size, num_heads, seq_len, head_dim) query = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32) key = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32) value = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32) # Lets create a fun new score_modification! I will call this # Checkerboard. It will reduce the score for neighboring tokens (1 step apart) # in the sequence. And increase the score for tokens 2 steps apart. For everything # else, the score will remain the same. def checkerboard(score, batch, head, token_q, token_kv): score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score) score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score) return score # Lets call templated_attention with this new score modification output = templated_attention(query, key, value, score_mod=checkerboard) compiled_templated_attention = torch.compile(templated_attention) out_compiled = compiled_templated_attention(query, key, value, score_mod=checkerboard) torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2) ``` ### Future Work - This PR is currently only forward only. 
### Future Work
- This PR currently supports the forward pass only. However, a Triton kernel for the backward pass, in which the score modifications do not rely on external buffers, has been explored here: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/flash/flash_attention.py
- Kernel improvements: there have been some larger updates to the fused attention implementation that Triton uses in its tutorials. The implementation of this kernel is based on a prior version and should be updated.
- We may want to unify this API under the top-level SDPA API; that is left as a follow-up once this is more stable.
- Should we error on CPU?
- There are some issues with dynamic shapes.
- Capturing free variables and lifting them to inputs of the subgraph is not working correctly today.

### Performance
Comparisons generated by this benchmark:

| Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average | 5.412 | | | | | | | |
| Max | 8.882 | 16 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 |
| Min | 3.645 | 8 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 |
| Min | 0.345 | 1 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 |

For reference:

| Configuration | Forward Time (µs) | Backend | Speedup |
|-----------------------------------------------|--------------------------|------------------|---------|
| Fastest Config in Sweep (`8 16 4096 4096 64 relative_bias torch.bfloat16`) | 3608 | Templated Attention | 1.0 |
| Compiled SDPA (No Mask) | 9928 | Math | 2.75x |
| Compiled SDPA (With Mask) | 11898 | Math | 3.29x |
| Compiled SDPA (With Mask) | 8704 | Memory Efficient Attention | 2.42x |
| Compiled SDPA (No Mask) | 2548 | FlashAttention2 | 0.706x |

The speedups compare compiled templated attention against different calls to torch.nn.functional.scaled_dot_product_attention.
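The `score_mod` column in the tables names the modification benchmarked for each configuration. Purely as an illustration (these are assumptions following the API above, not the benchmark's actual definitions), `causal_mask` and `relative_bias` style mods could look like:

```Python
import torch

# Illustrative guesses at the kinds of score_mods behind the benchmark labels;
# they follow the score_mod API described above but are not the benchmark's exact code.
def causal_mask(score, batch, head, token_q, token_kv):
    # Keep the score when the key position does not exceed the query position,
    # otherwise mask it out with -inf (which softmax turns into zero weight).
    return torch.where(token_q >= token_kv, score, float("-inf"))

def relative_bias(score, batch, head, token_q, token_kv):
    # Bias the score by the signed distance between the query and key positions.
    return score + (token_q - token_kv)
```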
<details>
<summary> FULL PERFORMANCE SWEEP NUMBERS </summary>

| batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype | eager_time | compiled_time | speedup |
|--------------|-------------|-------------|-------------|------------|---------------|----------------|--------------|-----------------|-----------|
| 1 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | 331.444 | 67.221 | 4.931 |
| 1 | 16 | 512 | 512 | 64 | relative_bias | torch.bfloat16 | 335.300 | 64.187 | 5.224 |
| 1 | 16 | 512 | 512 | 64 | head_bias | torch.bfloat16 | 352.039 | 63.806 | 5.517 |
| 1 | 16 | 512 | 512 | 64 | pathological | torch.bfloat16 | 371.699 | 711.349 | 0.523 |
| 1 | 16 | 1024 | 1024 | 64 | causal_mask | torch.bfloat16 | 333.488 | 86.455 | 3.857 |
| 1 | 16 | 1024 | 1024 | 64 | relative_bias | torch.bfloat16 | 322.363 | 82.469 | 3.909 |
| 1 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | 349.967 | 82.233 | 4.256 |
| 1 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | 486.359 | 1412.453 | 0.344 |
| 1 | 16 | 4096 | 4096 | 64 | causal_mask | torch.bfloat16 | 2794.597 | 551.188 | 5.070 |
| 1 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | 3965.150 | 513.101 | 7.728 |
| 1 | 16 | 4096 | 4096 | 64 | head_bias | torch.bfloat16 | 2408.013 | 504.759 | 4.771 |
| 1 | 16 | 4096 | 4096 | 64 | pathological | torch.bfloat16 | 6850.531 | 16733.675 | 0.409 |
| 8 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | 441.939 | 123.576 | 3.576 |
| 8 | 16 | 512 | 512 | 64 | relative_bias | torch.bfloat16 | 560.379 | 116.710 | 4.801 |
| 8 | 16 | 512 | 512 | 64 | head_bias | torch.bfloat16 | 421.172 | 115.825 | 3.636 |
| 8 | 16 | 512 | 512 | 64 | pathological | torch.bfloat16 | 994.492 | 2132.806 | 0.466 |
| 8 | 16 | 1024 | 1024 | 64 | causal_mask | torch.bfloat16 | 1436.430 | 309.495 | 4.641 |
| 8 | 16 | 1024 | 1024 | 64 | relative_bias | torch.bfloat16 | 1892.216 | 290.186 | 6.521 |
| 8 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | 1360.665 | 282.956 | 4.809 |
| 8 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | 3525.532 | 8359.702 | 0.422 |
| 8 | 16 | 4096 | 4096 | 64 | causal_mask | torch.bfloat16 | 22026.839 | 3864.604 | 5.700 |
| 8 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | 31262.746 | 3609.551 | 8.661 |
| 8 | 16 | 4096 | 4096 | 64 | head_bias | torch.bfloat16 | 20219.079 | 3480.402 | 5.809 |
| 8 | 16 | 4096 | 4096 | 64 | pathological | torch.bfloat16 | 54654.647 | 116652.357 | 0.469 |
| 16 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | 820.606 | 188.683 | 4.349 |
| 16 | 16 | 512 | 512 | 64 | relative_bias | torch.bfloat16 | 1058.362 | 179.295 | 5.903 |
| 16 | 16 | 512 | 512 | 64 | head_bias | torch.bfloat16 | 784.372 | 175.714 | 4.464 |
| 16 | 16 | 512 | 512 | 64 | pathological | torch.bfloat16 | 1890.792 | 4212.877 | 0.449 |
| 16 | 16 | 1024 | 1024 | 64 | causal_mask | torch.bfloat16 | 2781.830 | 557.017 | 4.994 |
| 16 | 16 | 1024 | 1024 | 64 | relative_bias | torch.bfloat16 | 3694.050 | 525.249 | 7.033 |
| 16 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | 2634.164 | 507.613 | 5.189 |
| 16 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | 6959.917 | 15331.116 | 0.454 |
| 16 | 16 | 4096 | 4096 | 64 | causal_mask | torch.bfloat16 | 43889.096 | 7582.018 | 5.789 |
| 16 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | 62784.293 | 7075.846 | 8.873 |
| 16 | 16 | 4096 | 4096 | 64 | head_bias | torch.bfloat16 | 40308.606 | 6829.587 | 5.902 |
| 16 | 16 | 4096 | 4096 | 64 | pathological | torch.bfloat16 | 108892.137 | 233090.953 | 0.467 |

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121845
Approved by: https://github.com/Chillee, https://github.com/zou3519