pytorch/torch/nested
Joel Schlosser 31e59766e7 Fix meta registration for _flash_attention_forward() (#119812)
The meta registration wrongly assumes 4D inputs, while the underlying op also accepts 3D inputs in the variable-length `mha_varlen_fwd()` case.
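For illustration, here is a minimal sketch of rank-aware shape handling in a meta kernel; the function name `meta_flash_attention_forward`, its signature, and the layout comments are assumptions for this example, not the actual PyTorch registration:

```python
import torch

# Hypothetical sketch: compute output metadata for both the dense 4D layout
# and the packed 3D varlen layout, instead of assuming 4D unconditionally.
# Only the query's metadata is inspected here, a simplification.
def meta_flash_attention_forward(query, key, value):
    if query.dim() == 4:
        # Dense layout: (batch, seqlen, num_heads, head_dim)
        batch, seqlen, num_heads, head_dim = query.shape
        out_shape = (batch, seqlen, num_heads, head_dim)
    elif query.dim() == 3:
        # Packed varlen layout used by mha_varlen_fwd():
        # (total_tokens, num_heads, head_dim); the batch size comes from
        # cumulative sequence lengths, not from the tensor shape.
        total_tokens, num_heads, head_dim = query.shape
        out_shape = (total_tokens, num_heads, head_dim)
    else:
        raise RuntimeError(f"expected 3D or 4D query, got {query.dim()}D")
    # Meta tensors carry only shape/dtype/device; no storage is allocated.
    return query.new_empty(out_shape)

q = torch.empty(2, 128, 8, 64, device="meta")
print(meta_flash_attention_forward(q, q, q).shape)  # torch.Size([2, 128, 8, 64])
```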
Testing: I added `detach()` calls so the NJT test `test_sdpa_compile()` no longer fails for an unrelated view reason; with this fix, it should now pass.
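For context, a minimal sketch of the `detach()` pattern described above (the shapes, dtype, and CUDA setup are assumptions for illustration, not the actual test body):

```python
import torch
import torch.nn.functional as F

def sdpa(q, k, v):
    return F.scaled_dot_product_attention(q, k, v)

compiled_sdpa = torch.compile(sdpa)

# Detach the inputs so the compiled run doesn't trip over an unrelated
# view-of-graph-input failure; the flash backend expects CUDA + fp16/bf16.
q, k, v = [
    torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16,
                requires_grad=True).detach()
    for _ in range(3)
]
out = compiled_sdpa(q, k, v)
```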
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119812
Approved by: https://github.com/drisspg
2024-02-14 02:38:53 +00:00
_internal Fix meta registration for _flash_attention_forward() (#119812) 2024-02-14 02:38:53 +00:00
__init__.py Fix return type hint for list types (#118238) 2024-01-25 23:35:20 +00:00