Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-08 07:39:33 +01:00
Meta registration wrongly assumes 4D inputs, while the underlying op allows 3D inputs for the `mha_varlen_fwd()` case.

Testing: I added `detach()` calls so the NJT test `test_sdpa_compile()` won't fail for a view-related reason. It should pass now with this fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119812
Approved by: https://github.com/drisspg

||
|---|---|---|
| _internal | ||
| __init__.py | ||