Summary:
This PR adds a lowering for `torch._cslt_sparse_mm` to find the optimal
alg_id and cache it when running with `torch.compile`
Seeing speedups on both bfloat16 and float8 dtypes:
<img width="641" alt="Screenshot 2024-10-17 at 2 10 38 PM" src="https://github.com/user-attachments/assets/b928cd11-32a3-43e5-b209-8e4028896f0b">
<img width="1274" alt="Screenshot 2024-10-17 at 1 39 03 PM" src="https://github.com/user-attachments/assets/d9edd684-a8ec-46fd-b3da-2e76dbcb7bb6">
* `torch._cslt_sparse_mm_search` has been modified to return optimal
split-k parameters as well as max alg_id.
* max_id is now available in `torch.backends.cusparselt` via
`torch.backends.cusparselt.get_max_alg_id()`
* fixed meta registrations for float8
Test Plan:
python test/test_sparse_semi_structured.py
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch
This PR adds new meta functions for `lerp`, `addcmul`, and `addcdiv` (including their
respective inplace versions).
These functions only had refs implementations, which was being the root cause of a
significant overhead ([issue][1]) when running `AdamW` optimizer step on PyTorch/XLA
backend. Running the meta functions resulted in the following improvements:
- `lerp` calls: 1,550ms to 140ms (10x)
- `addcdiv` calls: 640ms to 350ms (1.8x)
- `addcmul` calls: 620ms to 300ms (2.05x)
[1]: https://github.com/pytorch/xla/issues/7923
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136909
Approved by: https://github.com/jansel
Summary:
# Latest Update
This diff is no longer needed because we did need the check to exist, to make meta behave the same as other devices, see D54526190.
---------------------------------
# Background
T176105639
| case | embedding bag weight | per_sample_weight | fbgemm lookup | forward in meta |
| A | fp32 | fp32 | good | good |
| B | fp16 | fp32 | good| failed [check](https://fburl.com/code/k3n3h031) that forces weight dtype == per_sample_weights dtype |
| C | fp16 | fp16 | P1046999270, RuntimeError: "expected scalar type Float but found Half from fbgemm call" | good |
| D | fp32 | fp16 | N/A | N/A |
Currently we are in case A. Users need to add `use_fp32_embedding` in training to force embedding bag dtype to be fp32. However, users actually hope for case B to use fp16 as the embedding bag weight. When deleting `use_fp32_embedding`, they would fail the [check](https://fburl.com/code/k3n3h031) that forces `weight dtype == per_sample_weights dtype ` in meta_registration.
The check is actually not necessary. Is it because the backend fbgemm does support case B. Additionally, later on in the `meta_embedding_bag`, `weight` and `per_sample_weights` don't need to be in the same dtype (https://fburl.com/code/q0tho05h, weight is src, per_sample_weights is scale) for `is_fast_path_index_select`.
# This diff
Therefore, this diff remove the unnecessary [check](https://fburl.com/code/k3n3h031) to support case B in meta forward. With such, users are able to use fp16 to be the emb bag dtype without the need to force per_sample_weights the same dtype in meta forward (see Test Plan).
# Reference diffs to resolve this issue
Diff 1: D52591217
This passes embedding bag dtype to feature_processor to make per_sample_weights same dtype as emb bag weight. However, `is_meta` also needs to be passed because of case C. fbgemm still does not support per_sample_weights = fp16 (see the above table). Therefore users are forced to only make per_sample_weights fp16 when it is on meta. The solution requires too many hacks.
Diff 2: D53232739
Basically doing the same thing in diff 1 D52591217, except that the hack is added in TorchRec library. This adds an if in EBC and PEA for: when emb bag weight is fp16, it forces per_sample_weight fp16 too. However, it would then result in fbgemm issue too and has broken a bunch of prod models.
Test Plan:
# APS
The following command will run icvr_launcher which triggers ads_launcher and run forward in meta device:
```
buck2 run mode/opt -c python.package_style=inplace //aps_models/ads/icvr:icvr_launcher_publish -- mode=mast_ig_fm_when_combo0_uhm_publish launcher.fbl_entitlement=ads_global_tc_ads_score launcher.data_project=oncall_ads_model_platform launcher.tags=[ads_ranking_taxonomy_exlarge_fm_prod] stages.train=false
```
Result:
{F1461463993}
Reviewed By: ezyang
Differential Revision: D54175438
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136774
Approved by: https://github.com/ezyang
Partially addresses https://github.com/pytorch/pytorch/issues/128150
When you have big sums of values, we end up computing long chains of
binary addition in our FX graph representation. Not only is this ugly,
it also is quadratic, as the sympy.Add constructor is O(N) in number
of arguments. Instead, ensure that we maintain the summation as a
single FX node so we can do the entire addition all in one go.
update_hint_regression benchmark, before and after:
```
update_hint_regression,compile_time_instruction_count,2648328980
update_hint_regression,compile_time_instruction_count,2563748678
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136429
Approved by: https://github.com/isuruf
Adds lowering for `aten.searchsorted`. This entails:
1. Adding support for multi-dimensional bucket tensors to `ops.bucketize`.
2. Adding support for striding to `ops.bucketize`.
3. Adding support for sorting tensors to `ops.bucketize`.
4. Adding a lowering for `aten.searchsorted.Tensor`.
5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors.
6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions.
Closes#135873
Differential Revision: [D63766514](https://our.internmc.facebook.com/intern/diff/D63766514)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135701
Approved by: https://github.com/amjames, https://github.com/eellison, https://github.com/davidberard98
This PR adds new meta functions for `lerp`, `addcmul`, and `addcdiv` (including their
respective inplace versions).
These functions only had refs implementations, which was being the root cause of a
significant overhead ([issue][1]) when running `AdamW` optimizer step on PyTorch/XLA
backend. Running the meta functions resulted in the following improvements:
- `lerp` calls: 1,550ms to 140ms (10x)
- `addcdiv` calls: 640ms to 350ms (1.8x)
- `addcmul` calls: 620ms to 300ms (2.05x)
[1]: https://github.com/pytorch/xla/issues/7923
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136909
Approved by: https://github.com/jansel
Fixes #127049
There's already a meta func in `meta_registrations.py` for `add_` and `sub_` methods. I added a second meta function for error checking, i.e `int.add/sub_(float)` and `bool.add/sub_(other types)` .
Also the corresponding test with Dynamo passes, removed `@xfailIfTorchDynamo`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135864
Approved by: https://github.com/williamwen42
Fixes#136090
* Add support for isin to tensor half dtypes for CPU (just add a few extra dispatches).
* Seems like the CUDA implementation for bfloat16 was mostly compiled and available all along (it just calls sort internally AND unique). To enable it, we just need to remove an assert to access it (since sort's functionality was updated since the assert was added) and add missing dtype support to unique.
* This unlocks more GPU functionality with minimal code bloat. I also added CPU kernels for the dtypes for parity.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136114
Approved by: https://github.com/malfet
This PR:
* Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values)
* Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
* Note: there is currently no public API for this; design booted to a future PR
TODO:
* ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~
* ~~Verify that Inductor does computation fusion via test logic~~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947
Approved by: https://github.com/soulitzer
See #121528 for additional context.
In #120682, we moved the attention kernels from meta_registrations to fake_impls with the intent of fixing the device handling for seed/offset: these are typically on CPU. We needed to put the registrations in fake_impls to do this because meta_registrations doesn't have a way to specify device, whereas fake_impls does. But when we tried to actually fix the device types (#120839), we had to revert the PR because it broke cudagraph handling (during which seed/offset _are_ on CUDA).
Now, we want to put the registrations back in meta_registrations so that we can call these kernels with meta tensors. The use case is later in this stack - we want to be able to use the flop counter with these kernels.
Also - I specifically skip the `compare_tensor_meta()` check in test_fake / test_fake_autocast tests for the `_efficient_attention_forward` and `_flash_attention_forward` kernels, which fails because of the device mismatch from the seed/offset tensors. Then we can un-skip these opinfos. I verified that the efficient_attention_forward bug (#120842) is now caught by these opinfos if I revert the fix from this PR.
Differential Revision: [D61687369](https://our.internmc.facebook.com/intern/diff/D61687369)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134288
Approved by: https://github.com/drisspg
Because aten.poisson doesn't have meta function registered, there is one additional eager execution of this op during compilation phase of torch.compile.
There are more ops without meta registration. Is there any reason for it?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134103
Approved by: https://github.com/ezyang
Add the Inductor lowering for `torch._scaled_mm`, whose API was last updated in https://github.com/pytorch/pytorch/pull/128683.
The lowering does:
- for tensor-wise scaling, auto-tune between the default ATen kernel (cuBLAS) and Triton kernel configurations.
- for row-wise scaling, auto-tune between the default ATen kernel (CUTLASS kernel added in https://github.com/pytorch/pytorch/pull/125204) and Triton kernel configurations.
The Triton kernel template is based on 3ad9031d02 (D56337896) by @choutim, without using SPLIT_K, and that of mm `torch/_inductor/kernel/mm.py`
## Testing:
- Logging shows max-autotune tuning (`AUTOTUNE scaled_mm`) for both tensor-wise and row-wise scaling when called with the two scaling types.
- Row-wise scaling allows operator fusion between preceding pointwise/reduction op and amax/cast:
- output code Evaluating m=256, n=256, k=256, fusion_case='pointwise', scaling_mode='row'
- P1477224245 - 2 kernels
- output code Evaluating m=2048, n=256, k=2048, fusion_case='reduction', scaling_mode='row'
- P1477227340 - 2 kernels
- UT `python test/inductor/test_fp8.py -- TestFP8Lowering`
## Benchmarking
Eager/compiled tensor-wise/row-wise scaling for various shapes:
https://docs.google.com/spreadsheets/d/1VfWEVuyrwoWysfbS0_u2VHJ-PsdWkF1qIsiD60AzTes/edit?gid=2113587669#gid=2113587669
- Some of the “compiled” cases are slightly slower than “eager”. It’s because max-autotune selected the ATen kernel in the compiled case, and I think the discrepancy is variance.
Eager/compiled tensor-wise/row-wise scaling with pointwise/reduction preceding op for various shapes:
https://docs.google.com/spreadsheets/d/1Nv07NrdffQIoDeMjo9E0V-E-EYrEN0WysO_bn1bc6ns/edit?gid=1715488446#gid=1715488446
## Questions for reviewers:
- Should the type of the accumulator `ACC_TYPE` always be in float32? If not, where is this type set (output layout?)?
## Todo:
- Make the Triton template use the improved persistent kernel version (https://github.com/pytorch/FBGEMM/pull/2735 by @htyu)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130422
Approved by: https://github.com/ipiszy
This PR is to update the input `weight` of `_convert_weight_to_int4pack` from `[n][k] int32` to `[n][k / 2] uint8`, both for CPU, CUDA and MPS, which can help decouple int4 model checkpoint with different ISAs and different platforms in `gpt-fast`. The advantage is int4 model checkpoint can be shared in different test machines, without re-generating in one certain platform. Meanwhile, the size of input `weight` can be reduced to `1 / 8`.
Before this PR, packed weight stored in CUDA specific layout: `[n/8][k/(InnerKTiles*16)][32][InnerKTiles/2]`, dtype int32, where InnerKTiles = 2, 4, 8. CPU packed weight viewed as the SAME shape but stored in different layout: `[n/64][k][32]`, dtype uint8. Weight is strongly coupled with platforms (CPU/CUDA) and ISAs (AVX512/AVX2/scalar). And users cannot use a generated weight in another different ISA or platform, because when loading weight into devices, the compute format is different.

Now, we use common serialized layout (`[n][k/2] uint8`) for different devices or ISAs as input `weight` of `_convert_weight_to_int4pack`, and each back chooses how to interpret as compute layout.

### Performance
Intel (R) Xeon (R) CPU Max 9480, single socket (56 cores)
There is no obvious regression of this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129940
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/mingfeima
This PR is to update the input `weight` of `_convert_weight_to_int4pack` from `[n][k] int32` to `[n][k / 2] uint8`, both for CPU, CUDA and MPS, which can help decouple int4 model checkpoint with different ISAs and different platforms in `gpt-fast`. The advantage is int4 model checkpoint can be shared in different test machines, without re-generating in one certain platform. Meanwhile, the size of input `weight` can be reduced to `1 / 8`.
Before this PR, packed weight stored in CUDA specific layout: `[n/8][k/(InnerKTiles*16)][32][InnerKTiles/2]`, dtype int32, where InnerKTiles = 2, 4, 8. CPU packed weight viewed as the SAME shape but stored in different layout: `[n/64][k][32]`, dtype uint8. Weight is strongly coupled with platforms (CPU/CUDA) and ISAs (AVX512/AVX2/scalar). And users cannot use a generated weight in another different ISA or platform, because when loading weight into devices, the compute format is different.

Now, we use common serialized layout (`[n][k/2] uint8`) for different devices or ISAs as input `weight` of `_convert_weight_to_int4pack`, and each back chooses how to interpret as compute layout.

### Performance
Intel (R) Xeon (R) CPU Max 9480, single socket (56 cores)
There is no obvious regression of this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129940
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/mingfeima
This PR modifies `_embedding_bag_backward` item inside _native_functions.yaml_, so that it
dispatches to CPU and CUDA directly, instead of `CompositeImplicitAutograd`.
*Context:* PyTorch operations that have the `CompositeImplicitAutograd` dispatch do not
allow third party backends (e.g. XLA) to modify its implementation, since this dispatch
key has higher priority. When calling `_embedding_bag_backward` operation using XLA, a
dispatch error will be thrown, since PyTorch/XLA doesn't support sparse tensors.
*Problem:* `_embedding_bag_backward` has a `sparse` parameter that controls whether the
operation should return a sparse or dense tensor. However, at the moment, PyTorch/XLA does
not support sparse tensors. In order to fallback that execution to dense, i.e. change the
flag at runtime, we need to be able to modify its implementation.
*Solution:* we have changed the dispatch of `_embedding_bag_backward` to CPU and CUDA,
which allowed us to introduce our own kernel for it.
Additionally, this PR refactored the representation of its mode from constant integers
into an enum class. It also introduces two additional operators: `int == EmbeddingBagMode`
and `int != EmbeddingBagMode`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129691
Approved by: https://github.com/lezcano
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
Summary:
To fix the following failure cases:
For example, when `M, K, N = 245760, 656, 6560`, fp8 with compile fails due to `RuntimeError: mat2 must be col_major`.
---------
From the inductor generated code (https://fburl.com/everpaste/epcagkrd)
```
V0625 01:38:55.551000 140329914449920 torch/_inductor/scheduler.py:1623] [0/0] scheduling ComputedBuffer(name='buf12', layout=FixedLayout('cuda', torch.float8_e4m3fn, size=[656, 6560], stride=[6656, 1]),
... ...
V0625 01:38:56.194000 140329914449920 torch/_inductor/graph.py:1680] [0/0] [__output_code] buf12 = empty_strided_cuda((656, 6560), (6656, 1), torch.float8_e4m3fn)
... ...
V0625 01:38:56.194000 140329914449920 torch/_inductor/graph.py:1680] [0/0] [__output_code] return (buf10, buf2, buf5, buf6, reinterpret_tensor(buf11, (245760, 656), (1, 245760), 0), reinterpret_tensor(buf12, (6560, 656), (1, 6656), 0), )
... ...
V0625 01:39:12.098000 140312968167424 torch/_inductor/graph.py:1680] [1/0_1] [__output_code] assert_size_stride(permute_10, (6560, 656), (1, 6656))
... ...
V0625 01:39:12.098000 140312968167424 torch/_inductor/graph.py:1680] [1/0_1] [__output_code] buf8 = aten._scaled_mm.default(buf6, permute_10, buf7, reciprocal_3, None, None, torch.bfloat16)
```
Inductor gives the mat2 (`permute_10`) a different stride (`6656`) instead of using its shape[0] (`(6560, 656)`).
Therefore, the `stride[1] == shape[0]` condition fails.
To fix the issue, simply modify the `is_col_major` check to exclude this condition as it doesn't hold for all valid cases.
Test Plan:
Run the failed case again. It works with the fix.
-----
Sandcastle / GitHub CI will make sure the existing tests could still pass.
Reviewed By: vkuzo
Differential Revision: D58994704
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129521
Approved by: https://github.com/drisspg
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.
What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...
Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
# Summary
The primary reason for the change was lack of current use case and the need to work around an two Inductor issue.
- Tensor arguments as kwarg only
- multiple outputs from triton templates
If the need for the amax return type arises we can consider either adding it, more likely creating a separate op.
In principle PyTorch is moving away from ops that bundle lots of functionality into "mega ops". We instead rely upon the compiler to generate appropriate fused kernels.
### Changes:
- This removes the amax return type from scaled_mm. We have found that the common use case is to return in "high-precision" ( a type with more precision than fp8). This is only relevant when returning in low-precision.
- We currently still allow for fp8 returns and scaled result. Perhaps we should also ban this as well...
New signature:
```Python
def meta_scaled_mm(
self: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: Optional[torch.Tensor] = None,
scale_result: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
) -> torch.Tensor:
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128683
Approved by: https://github.com/vkuzo
This PR lifts internal lowerings written for FBGEMM kernels that do jagged <-> padded dense conversions. In particular, this PR provides lowerings and meta registrations for the following ATen ops:
* `_jagged_to_padded_dense_forward()`
* `_padded_dense_to_jagged_forward()`
* NB: if `total_L` is not provided, the output shape is data-dependent. An unbacked SymInt is used for this case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125968
Approved by: https://github.com/davidberard98
Backporting a few fixes from xFormers:
* Bug fixes for local attention (which is not exposed in PT at the moment)
* Massively reduced memory usage on the BW pass (see also https://github.com/facebookresearch/xformers/pull/1028)
Essentially this will also make xFormers build process much easier, as we will be able to use mem-eff from PyTorch (if the user has a recent enough version) rather than building it at xFormers install time
The goal is to have the source of truth for these files in PT moving forward, and remove them from xFormers eventually once our users have a recent-enough version of PT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127090
Approved by: https://github.com/drisspg
This PR adds _foreach_max support, the second reduction foreach op we have :D
I did have to change the autogen slightly for foreach. I can promise that the existing foreach ops' derivative behavior has not changed as I've added a skip list for the harder requirement I am setting (that the arg list should match in length). I needed to add this requirement as there is another wrong max (the one that does take in a dim for reduction) that keeps getting matched first.
Caveats!
- We do not fast path if the shapes, dtypes, device, the regular shebang for foreach are not met. We fall back to slowpath!
- MORE IMPORTANTLY, we also do not fast path for int8 and int16 and bool, but that's really a skill issue on my end as I've hardcoded -INFINITY into the CUDA kernels, and -INFINITY is not defined for small ints. It'd be nice to know how to do this properly, but that work can also come later.
- This does NOT support empty Tensors in the list, because the original max op also does not support empty Tensors. ~I think this should be allowed though, and this PR may come later.~ I understand why this is not allowed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127187
Approved by: https://github.com/albanD
Summary:
This PR implements sliding window and updates "aten._flash_attention_forward/_flash_attention_backward" to expose the window_size_left and window_size_right arguments. With this kwarg added we can dispatch to the FAv2 impl if the necessary constraints are met.
These arguments will eventually be provided to "aten.sdpa_flash" but for now they are needed when called by xformers into their effort to directly use the Pytorch FAv2 impl instead of building their own.
Test Plan:
Use the default aten.sdpa_flash tests since we've added optional arguments set to the previous default value: -1, /*window_size_left*/
Using buck2 build --flagfile fbcode//mode/dev-nosan fbcode//caffe2/caffe2/fb/predictor/tests:inference_context_test
Differential Revision: D56938087
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126061
Approved by: https://github.com/drisspg, https://github.com/desertfire
> previous: Originally, the variables `new_eta` and `new_mu` would be constructed `len(grouped_mus)` times, but each of their values is the same and won't be changed. Therefore, it can be simplified using Python list multiplication, which only constructs one tensor.
- [X] Ill assumption that every param will have the same step.
- [x] DIfferent implementation between `foreach=Ture` and `foreach=False`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125440
Approved by: https://github.com/janeyx99
This commit introduces a meta function for the `channel_shuffle` operation, enabling PyTorch to perform shape inference and optimizations related to this operation without actual computation. The meta function assumes input shape (*, C, H, W) and validates that the number of channels (C) is divisible by the specified number of groups.
Fixes#122771
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123033
Approved by: https://github.com/ezyang, https://github.com/mikaylagawarecki
Given the following code/dynamo graph:
```
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
_print = torch.ops.aten._print('moo')
res = l_x_ + l_x_; l_x_ = None
_print_1 = torch.ops.aten._print('moo')
return (res,)
```
AOTAutograd will trace the following program, threading tokens from the inputs, through the effectful operator calls (torch.ops.aten._print), and as an output:
```
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[0]", arg1_1: "f32[2, 3]"):
with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
getitem: "f32[0]" = with_effects[0]; with_effects = None
add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2: "f32[0]" = with_effects_1[0]; with_effects_1 = None
return (getitem_2, add)
```
However when we get to inductor, since we want the inductor generated code to not have any token inputs/outputs for better readability, we want to modify the aten graph by removing the tokens from inputs, and creating them through `torch.ops.aten._make_dep_token`, and sinking them through the `torch.ops.aten._sink_tokens` operators.
This has to be done *after* the partitioner, otherwise the partitioner will add the make_token/sink_token operators to the backwards graph.
```
class <lambda>(torch.nn.Module):
def forward(self, arg1_1: "f32[2, 3]"):
_make_dep_token_default: "f32[0]" = torch.ops.aten._make_dep_token.default()
with_effects = torch._higher_order_ops.effects.with_effects(_make_dep_token_default, torch.ops.aten._print.default, 'moo'); _make_dep_token_default = None
getitem: "f32[0]" = with_effects[0]; with_effects = None
add: "f32[2, 3]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
getitem_2: "f32[0]" = with_effects_1[0]; with_effects_1 = None
_sink_tokens_default = torch.ops.aten._sink_tokens.default((getitem_2,)); getitem_2 = None
return (add,)
```
When doing inductor lowering, we convert `with_effects` calls to an `EffectfulKernel`, which just a `FallbackKernel` but with a pointer to previous effectful operator's call. During scheduling, we will create a `StarDep` between the EffectfulKernel and its previous EffectfulKernel so that they don't get reordered. The inductor generated python code looks like:
```
def call(args):
arg1_1, = args
args.clear()
assert_size_stride(arg1_1, (2, 3), (3, 1))
# Source Nodes: [_print], Original ATen: []
buf2 = aten._print.default('moo')
# Source Nodes: [_print_1], Original ATen: []
buf3 = aten._print.default('moo')
buf4 = empty_strided_cpu((2, 3), (3, 1), torch.float32)
cpp_fused_add_0(arg1_1, buf4)
del arg1_1
return (buf4, )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122347
Approved by: https://github.com/bdhirsh
`linalg_eigvals_out` calls into a dispatch stub, so only supports CPU and CUDA
strided tensors but incorrectly claimed to be a composite op. `linalg_eigvals`
also shouldn't defer to the out variant inside a `CompositeImplicitAutograd` op
as not all types support out variants. Instead, I add a new helper
`_linalg_eigvals` which does the same thing in a non-composite operator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121142
Approved by: https://github.com/lezcano
**description**
Enable lowering of dynamic qlinear for X86Inductor. The pattern is `choose_qparams -> getitem -> q -> dq -> linear`. We only fuse `dq -> linear` and get `choose_qparams -> getitem -> q -> onednn.qlinear_pointwise`. So, we treat it as dynamic quantization of activation + static quantized linear.
The previous implementation of `onednn.qlinear_pointwise` is for the case where `x_scale` and `x_zp` are scalars. Since `choose_qparams` returns tensors, we added a variation `onednn.qlinear_pointwise.tensor` to support the case.
This feature is targeting PyTorch 2.3 release.
**Test plan**
```
python inductor/test_mkldnn_pattern_matcher.py -k test_dynamic_qlinear_cpu
python inductor/test_mkldnn_pattern_matcher.py -k test_dynamic_qlinear_qat_cpu
python inductor/test_cpu_cpp_wrapper.py -k test_dynamic_qlinear
```
**Performance before and after lowering `choose_qparam` to Inductor**
Before
- latency for shape (32, 32) = 0.151 ms
latency for shape (128, 128) = 0.153 ms
latency for shape (1024, 1024) = 0.247 ms
After
- latency for shape (32, 32) = 0.049 ms
- latency for shape (128, 128) = 0.052 ms
- latency for shape (1024, 1024) = 0.133 ms
Test method: A module with a single Linear layer, dynamic-quantize, lower to X86Inductor
Test env & config: Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz, single instance, single core, using Intel OpenMP and Tcmalloc
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120605
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jerryzh168
This PR is mostly just code movement to make the code review easier - AFAIK it should not change any functionality. The final goal is to remove the xfails for some of the test_fake opinfos for these ops. The opinfos are failing because the outputs can have mixed devices - we need to move them to fake_impls first before we can support mixed device returns.
This PR:
* Move the `_meta_registrations.py` implementations to `fake_impls.py`
* Change the function signature from taking explicit named variables to taking `{args, kwargs}` and normalizing them
* Wrap all the returned tensors in FakeTensors
Tests: relying on opinfos. I also checked `test_fake_*` for these tests (by removing x-fails and patching things until they passed) to verify general correctness.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120682
Approved by: https://github.com/drisspg
The first try reused TensorListMetadata, which caused illegal memory access issues when there were too many tensors in the list. We just launch multiple kernels with a simpler version of the struct (to minimize kernels launched).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119927
Approved by: https://github.com/albanD
Meta registration wrongly assumes 4D inputs, while the underlying op allows 3D inputs for the `mha_varlen_fwd()` case.
Testing: I added `detach()`es so the NJT test `test_sdpa_compile()` won't fail for a view-related reason. It should pass now with this fix.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119812
Approved by: https://github.com/drisspg
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
This should fix remaining errors with Resize op in torchvision: https://github.com/pytorch/vision/actions/runs/7298953575?pr=8127
```
/opt/conda/envs/ci/lib/python3.8/site-packages/torch/nn/functional.py:4072: in interpolate
return torch._C._nn._upsample_bicubic2d_aa(input, output_size, align_corners, scale_factors)
E torch._dynamo.exc.TorchRuntimeError: Failed running call_function <function interpolate at 0x7f4443fe00d0>(*(FakeTensor(..., size=(1, s0, s1, s2)),), **{'size': [s4, floor(s3*s4/floor(s1*s3/s2))], 'mode': 'bicubic', 'align_corners': False, 'antialias': True}):
E aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:5567: SymIntArrayRef expected to contain only concrete integers
E
E from user code:
E File "/pytorch/vision/torchvision/transforms/v2/functional/_geometry.py", line 260, in resize_image
E image = interpolate(
E
E Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E
E
E You can suppress this exception and fall back to eager by setting:
E import torch._dynamo
E torch._dynamo.config.suppress_errors = True
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117347
Approved by: https://github.com/peterbell10
Summary:
This PR adds in support for passing in a alpha Tensor, which represents
a tensor of alpha values to fuse into the matmul.
```
cusparselt_sparse_mm = alpha A @ B + bias
```
This operation is necessary for quantization, where we would like to
fuse one of the dequant matmuls into the sparse op.
Test Plan:
```
python test/test_sparse_semi_structured -k alpha
```
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112056
Approved by: https://github.com/cpuhrsch
_cslt_sparse_mm + additional stride checking in test.
Summary:
This PR adds in meta registrations for _cslt_sparse_mm.
Based on the work @drisspg did
in #114370.
Additionally, it updates the tests by checking that the strides of the
spare result and the result returned by sparse+compile are the same, to
avoid errors like those found in
https://github.com/pytorch/pytorch/pull/114477.
Test Plan:
```
python test/test_sparse_semi_structred -k compile_cusparselt
python test/test_sparse_semi_structred -k compile_cutlass
```
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114685
Approved by: https://github.com/alexsamardzic, https://github.com/drisspg
# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)
This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and causes un-needed memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem eff attention.
## Changes
Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.
Should this be warn_once?
We only call expand, once on the aligned mask.
Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115
@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
# Summary
Improved Fix for Attention Mask Alignment Issue (#112577)
This PR addresses Issue #112577 by refining the previously implemented fix, which was found to be incorrect and causes un-needed memory regressions. The update simplifies the approach to handling the alignment of the attention mask for mem eff attention.
## Changes
Alignment Check and Padding: Initially, the alignment of the attention mask is checked. If misalignment is detected, padding is applied, followed by slicing. During this process, a warning is raised to alert users.
Should this be warn_once?
We only call expand, once on the aligned mask.
Reference
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/fmha/cutlass.py#L115
@albanD, @mruberry, @jbschlosser, @walterddr, and @mikaylagawarecki.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114173
Approved by: https://github.com/danthe3rd
masked_scatter_backward was previously implemented as a
CompositeExplicitAutograd, which involved a decomp that calls
masked_select, and masked_select in general produces data-dependent
shapes that inductor doesn't support. But masked_scatter_backward
reshapes the return value of masked_select such that the end result has
a static shape again.
I have converted masked_scatter_backward into an aten op to avoid this
issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109642
Approved by: https://github.com/ezyang
ghstack dependencies: #108170
Testing out some new rules that are in beta, I think I will apply this one codebase wide once it's out of preview. Replaces the hack of using `[:]` to do copies of list with the proper copy method. More efficient and more readable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112990
Approved by: https://github.com/ezyang
- Extend `test_torch_dispatch_meta_outplace` to test torch ops that do not have an out parameter but have aten op overloads that have out parameters. Additionally, Python decompositions may register `OpOverloadPacket`'s so decompositions need to be tested to ensure all `OpOverloads` still function for the `Meta` key (e.g. if a python decomposition is registered for an aten op `aten.foo` with overloads `[default, out]`, the python function needs to support receiving out arguments)
- Add out parameter wrappers to python decomps for aten ops that have out overloads
CC. @ezyang @albanD @lezcano
Fixes#107713
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107707
Approved by: https://github.com/lezcano
aten.softmax will generate a different decomposition for fp16/bf16 and fp32 because when invoked in lower precision it will upcast the inputs to fp32 and then downcast after. This has been causing us to miss bf16 patterns. For example, Camembert improves 20% with this PR (as do I'm sure many other models).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109142
Approved by: https://github.com/yanboliang
ghstack dependencies: #109663, #108894, #108917
aten.softmax will generate a different decomposition for fp16/bf16 and fp32 because when invoked in lower precision it will upcast the inputs to fp32 and then downcast after. This has been causing us to miss bf16 patterns. For example, Camembert improves 20% with this PR (as do I'm sure many other models).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109142
Approved by: https://github.com/yanboliang
ghstack dependencies: #108894, #108917
This fixes numerous tests which were xfailing. For instance, the
`_segment_reduce.lengths` OpInfo test, which was previously relying on
the fallback kernel to determine the shape of the meta tensor. The
fallback kernel would fail with
segment_reduce(): Expected all rows of lengths along axis to sum to data.size(lengths.dim()-1) when !unsafe.
as it was trying to read the values of a meta tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109359
Approved by: https://github.com/ezyang
The sample inputs is a bit involved because there are a lot of
shenanigans in the derivative formula. Check comments.
This is exercised in vdd, internal test `buck2 run '@fbcode//mode/opt' fbcode//pytorch/benchmark/fb/test_gpu:run_test_gpu -- 'pytorch.benchmark.fb.test_gpu.test_gpu.TestBenchmarkFbGpu.test_train_blue_reels_vdd_v3_inductor_speedup'`
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109211
Approved by: https://github.com/albanD, https://github.com/zou3519
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That beings said I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985
### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao
### Changes Made
The majority of the changes in this pull request involve:
- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was need to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.
### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)
There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.
### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108
### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward( Tri beat me to it a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes
### Results
#### Forward only
The TFlops are reported here are on a100 that is underclocked.

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That beings said I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985
### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao
### Changes Made
The majority of the changes in this pull request involve:
- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was need to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.
### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)
There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.
### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108
### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward( Tri beat me to it a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes
### Results
#### Forward only
The TFlops are reported here are on a100 that is underclocked.

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
# Summary
## PR Dependencies
I don't use ghstack :( this is a PR where it would have been helpful. That beings said I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985
### Description
This pull request updates the version of _scaled_dot_product_flash_attention from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao
### Changes Made
The majority of the changes in this pull request involve:
- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was need to make the kernel functional and appease autograd.
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.
### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the files-changed are in the flash_attn/ folder. The only files of interest here IMO:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py ( this has been incorporated upstream to flash-attention github)
There are a number of files all related to avoiding OOMs in CI/CD. These are typically shell scripts.
### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108
### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - Need to update cmakee
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730 but since this will require changes to semantics of math to call repeat_interleave, I think this should be done as a followup.
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update test exercise above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward( Tri beat me to it a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from kernel to top level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes
### Results
#### Forward only
The TFlops are reported here are on a100 that is underclocked.

#### Forward+Backward
Ran a sweep and for large compute bound sizes we do see a ~2x performance increase for forw+back.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
We allow registering decomps for HigherOrderOp via the existing decomp
mechanisms:
- I refactored those APIs to accept torch._ops.OperatorBase, which is the base
class for torch.ops.HigherOrderOperator and torch.ops.OpOverload
- HigherOrderOps must directly call maybe_handle_decomp in their
ProxyTorchDispatchMode handling in order to resolve decompositions. We
can change this in the future so that they do not need to do this.
Next, we add an inductor decomp for out_dtype. This decomp shouldn't be
generally available because we want to preserve out_dtype to the backend
for other use cases (i.e. executorch).
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108080
Approved by: https://github.com/HDCharles
Some notable changes:
1. `constrain_as_size` allows min value to be less than 2 as it will unconditionally assume min >= 2 for compiler purposes. Instead, we add additional check to make sure max value is always greater than 2.
2. Previously, we used to runtime assert on the unbacked symint's val range which would be always between [2, max]. I modified this logic to assert on [0, max] unless user explicitly specifies the min range.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
Some notable changes:
1. `constrain_as_size` allows min value to be less than 2 as it will unconditionally assume min >= 2 for compiler purposes. Instead, we add additional check to make sure max value is always greater than 2.
2. Previously, we used to runtime assert on the unbacked symint's val range which would be always between [2, max]. I modified this logic to assert on [0, max] unless user explicitly specifies the min range.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
1. add a python meta registration, to fix an issue with the forward pass. The problem was that previously, the C++ meta registration calls [numel()](7b14a14e27/aten/src/ATen/native/TensorAdvancedIndexing.cpp (L329)) which fails (LMK if it's better to fix the C++ implementation to not do this check)
2. Modify the backward to fix an issue in the backward. The backward is not a custom op - it's a custom manual backward implementation. In particular, there's some situations that don't support double backward; the check for whether double backward is allowed requires a .item() call. To fix the meta/fake tensor case, this PR will avoid setting the double backward error only if `GradMode::is_enabled()` - which shouldn't be turned on in PT2.
3. Update skips.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106429
Approved by: https://github.com/zou3519
# Summary
### Review Points
- Automatically pad tensors to create aligned masks when seqlen_kv is not multiple of 16. This will cause memory spike ~ 2 * attn_mask size which could in theory be big. At appears though that doing this + mem_eff is faster than no_pad + math. SO seems to be worth it
- Using expand to view the attn_mask in 4d. This is a little different to how we enforce q,k,v to be viewed in 4d prior to calling. Also not supprint b*n_heads, seq_lenq, seq_lenkv case.
- Should enable, #96099
### Profiling
I ran a bunch of comparisons between sdpa.MATH and sdp.MemEffAttention. I added a attn_bias of shape (1, 1, seqlen_q, seqln_k). For these experiments seqlen_q == seqlen_k. These were all ran on an a100 80gb gpu.
Configs:
```
# Run a bunch of experiments
batch_sizes = [8, 16, 32]
num_heads = [16, 32]
max_seq_lens = [15, 64, 128, 512, 555, 1024]
embed_dims = [32, 64, 128]
dtypes = [torch.float16, torch.bfloat16, torch.float32]
pad_percentages = [None]
backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
run_backward = True
attn_mask = True
```
The function calls `sdpa(input**).sum().backward()`.
I calculated the geomean speedup of the efficient attention path of the math path for all these configs:
`Geomean Speedup: 1.977`
An example comparision with batchsize = 8, num_heads = 32, embed_dim = 64, and dtype = torch.float16:

This was done using the current state of the branch where we force alignment of mask when the last dim is not divisible by 16, which shows up in seq_len = 15 and 555 case.
The full data can be found here:
[attn_mask_sweep.csv](https://github.com/pytorch/pytorch/files/11962399/attn_mask_sweep.csv)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104310
Approved by: https://github.com/cpuhrsch