Commit Graph

3498 Commits

zeshengzong
a000c7e6d2 Add hint message for pack_padded_sequence (#146747)
Fixes #144207

Add a truncation hint to the docs for [torch.nn.utils.rnn.pack_padded_sequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html).
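A minimal sketch of the truncation behavior the hint describes (shapes and values are my own illustration):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# A batch of 2 sequences padded to length 5 (batch_first=True).
x = torch.arange(10, dtype=torch.float32).reshape(2, 5, 1)
lengths = torch.tensor([3, 2])

packed = pack_padded_sequence(x, lengths, batch_first=True)
# Only sum(lengths) == 5 time steps survive; padded steps beyond each
# sequence's length are dropped (truncated) from the packed data.
print(packed.data.shape)  # torch.Size([5, 1])
```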

## Test Result

![image](https://github.com/user-attachments/assets/46258f36-f6c7-4f11-9213-8513e52a9001)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146747
Approved by: https://github.com/mikaylagawarecki
2025-02-20 06:27:07 +00:00
Aaron Orenstein
db4ce78d46 PEP585: More UP006 fixes (#146392)
This should be the final PR before we can enable RUFF UP006.
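For context, the kind of rewrite UP006 enforces (a representative sketch, not a diff from this PR):

```python
# Before: PEP 484 typing aliases, flagged by ruff UP006.
from typing import Dict, List

def count(items: List[str]) -> Dict[str, int]:
    return {s: items.count(s) for s in set(items)}

# After: PEP 585 builtin generics, valid on Python 3.9+.
def count_585(items: list[str]) -> dict[str, int]:
    return {s: items.count(s) for s in set(items)}
```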

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
2025-02-20 06:18:13 +00:00
Simon Fan
ed83b0b70b [ddp] decouple python reducer from compilation mode (#147123)
The current implementation reads as: the "python_reducer" config is only actually used if the DDP forward is compiled; otherwise, we silently fall back to the C++ reducer with no DDPOptimizer.
I'm changing this behavior to always use the python reducer when the config is specified.
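A sketch of opting in, assuming the `torch._dynamo.config.optimize_ddp` knob used for compiled DDP (the exact config name is my assumption, not taken from this PR):

```python
import torch._dynamo

# After this change, the python reducer is used whenever this is set,
# even if the DDP forward itself is never compiled.
torch._dynamo.config.optimize_ddp = "python_reducer"
```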

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147123
Approved by: https://github.com/fegin
2025-02-19 15:51:40 +00:00
lzhang2
b16ae97ad0 Generalize mixed precision in DDP (#146808)
**Motivation:**

1. Generalize mixed precision in DDP.
2. Enable `SyncBatchNorm` for XPU device (see the sketch below).
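A minimal sketch for point 2, using the standard `convert_sync_batchnorm` entry point (the XPU support is what this PR enables):

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
)
# Swap BatchNorm layers for SyncBatchNorm so statistics are synced
# across ranks; per this PR the converted model can run on XPU too.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
```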

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146808
Approved by: https://github.com/guangyey, https://github.com/gujinghui, https://github.com/wconstab
2025-02-16 11:59:40 +00:00
zeshengzong
4a545eb85d Fix torch.nn.functional.one_hot param num_classes optional description (#146470)
The `torch.nn.functional.one_hot` [documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.one_hot.html) describes the `num_classes` parameter as required, but the function can be called without passing it.

![image](https://github.com/user-attachments/assets/4e6d4feb-691f-451f-95b5-4ac11bac7bc2)

```python
>>> import torch
>>> a = torch.arange(0, 5) % 3  # [0,1,2,0,1]
>>> torch.nn.functional.one_hot(a)
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0]])

```

`num_classes` has a default value of -1:

93d98aca31/aten/src/ATen/native/native_functions.yaml (L6154-L6157)

## Test Result

![image](https://github.com/user-attachments/assets/2c7203b7-6226-4ebc-84c8-cbf912fc48e2)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146470
Approved by: https://github.com/albanD
2025-02-06 07:48:05 +00:00
Aaron Gokaslan
7f65a20884 [BE]: Enable ruff SLOT checks (#146276)
This enables a check that a class which only inherits from immutable classes like str, tuple, and NamedTuple also defines `__slots__`, so its instances don't allocate memory unnecessarily. This also ensures contributors think about how they define classes that subclass NamedTuple and str, of which we have many in our codebase.
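A representative sketch of what the rule flags and how it is fixed:

```python
class ColorBad(str):   # flagged: subclasses an immutable type without
    pass               # __slots__, so every instance carries a __dict__

class ColorGood(str):
    __slots__ = ()     # fix: empty __slots__ keeps instances as cheap as str
```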

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146276
Approved by: https://github.com/aorenste
2025-02-04 19:18:23 +00:00
Aaron Gokaslan
292af3cc89 [BE][Ez]: ISC001 Auto concatenate implicit one line strings (#146408)
Apply the ruff rule for implicit string concatenation; this autofixes strings that are all the same type and on the same line. These lines were likely broken up by autoformatters in the past. All fixes are automated using the autofixes in ISC001.
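A representative before/after (my own example, not taken from the diff):

```python
# Flagged by ISC001: implicit concatenation of same-type literals on one
# line, typically left behind by an autoformatter.
greeting = "Hello, " "world"

# Autofixed to a single literal.
greeting_fixed = "Hello, world"
```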

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146408
Approved by: https://github.com/justinchuby, https://github.com/janeyx99
2025-02-04 19:07:04 +00:00
Sahdev Zala
f97307f463 [Docs] Add clarification for target types in CrossEntropyLoss doc (#145444)
The CrossEntropyLoss function requires that targets given as class indices be provided as a long dtype and targets given as class probabilities be provided as a float dtype.

The CrossEntropyLoss function distinguishes the two scenarios (indices vs. probabilities) by comparing shapes: when the input and target shapes are the same, the target is treated as probabilities; otherwise it is used as class indices, as already covered in the doc. The related code is here,
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossNLL.cpp#L624

I think the current documentation is great, but it seems it can confuse users about types, as reported in the issues, so this PR adds a bit more clarification.
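A small sketch of the two target conventions being clarified:

```python
import torch

loss_fn = torch.nn.CrossEntropyLoss()
logits = torch.randn(4, 3)

# Class indices: long dtype, shape (N,) -- differs from the input shape.
target_idx = torch.tensor([0, 2, 1, 0])
print(loss_fn(logits, target_idx))

# Class probabilities: float dtype, same (N, C) shape as the input.
target_prob = torch.softmax(torch.randn(4, 3), dim=1)
print(loss_fn(logits, target_prob))
```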

Fixes #137188

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145444
Approved by: https://github.com/mikaylagawarecki
2025-02-01 18:55:58 +00:00
Alexander Kurakin
35f113e2a0 torch/nn/utils/rnn.py: docs: improvements (#138628)
Fix constants highlighting in generated documentation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138628
Approved by: https://github.com/mikaylagawarecki
2025-02-01 00:10:30 +00:00
chilli
2d5d022594 Fix a number of flexattention issues (cse, cudagraph, etc.) (#145059)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145059
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2025-01-29 20:27:39 +00:00
Aaron Orenstein
7178b827d7 PEP585: Missed conversions (#145342)
Differential Revision: [D68785969](https://our.internmc.facebook.com/intern/diff/D68785969)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145342
Approved by: https://github.com/bobrenjc93
2025-01-29 05:24:36 +00:00
PyTorch MergeBot
09ae69a364 Revert "Fix type annotation of Linear.bias (#142326)"
This reverts commit 81e370fc6b.

Reverted https://github.com/pytorch/pytorch/pull/142326 on behalf of https://github.com/malfet due to This introduced a graph break and regressed inductor tests, see 73622fc5fa/1 ([comment](https://github.com/pytorch/pytorch/pull/142326#issuecomment-2614196349))
2025-01-26 03:41:00 +00:00
zeshengzong
5b988ac4fa [Easy] Replace paper description with link to make a concise description. (#145031)
The descriptions on the [Transformer](https://pytorch.org/docs/main/generated/torch.nn.Transformer.html), [TransformerEncoderLayer](https://pytorch.org/docs/main/generated/torch.nn.TransformerEncoderLayer.html), and [TransformerDecoderLayer](https://pytorch.org/docs/main/generated/torch.nn.TransformerDecoderLayer.html) pages contain author and paper details that seem redundant for users who just want to know how to use them. Replace them with a link to the paper, so users can read the details if they want to learn more.

**Test Result**

**Before**
![image](https://github.com/user-attachments/assets/678402b1-e759-402c-b56b-e24f63dc8490)
![image](https://github.com/user-attachments/assets/ca191734-f2ce-493f-bf34-2d7046a9868f)
![image](https://github.com/user-attachments/assets/10f55083-6eb6-4b1c-9a77-579f0c4c56ed)

**After**
![image](https://github.com/user-attachments/assets/020f81ca-d89b-47d1-a7a9-cae1893df968)
![image](https://github.com/user-attachments/assets/5b9b34df-b892-4d71-8cdb-df18380b2744)
![image](https://github.com/user-attachments/assets/b3348da2-842a-4037-bad3-f23687503cf8)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145031
Approved by: https://github.com/mikaylagawarecki
2025-01-24 23:01:02 +00:00
Fabian Keller
81e370fc6b Fix type annotation of Linear.bias (#142326)
Currently the `bias` attribute of `torch.nn.Linear` (and `Bilinear`) is typed incorrectly, because it relies on the implicit `Module.__getattr__`, which types it as `Tensor | Module`. This has two issues:

- It hides the fact that `bias` is optional and can be `None`, which in turn can hide actual bugs on the user side.
- It blurs the type by including `Module` in the union, which can require unnecessary `isinstance(linear.bias, Tensor)` checks on the user side.

This PR types the `bias` attribute explicitly to fix these issues.
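A sketch of the user-side friction being removed (the exact annotation chosen in the PR may differ):

```python
import torch
from torch import Tensor

layer = torch.nn.Linear(3, 3, bias=False)
assert layer.bias is None  # bias really can be None

# Under the old `Tensor | Module` inference, narrowing like this was
# needed before touching any Tensor attribute:
if isinstance(layer.bias, Tensor):
    print(layer.bias.shape)
```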

CC @ezyang @Skylion007

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142326
Approved by: https://github.com/ezyang
2025-01-24 22:43:52 +00:00
drisspg
c6707734de Enable non power of 2 head_dim for FlexAttention (#133495)
# Summary
- Adds support for non-power-of-2 head_dim by launching blocks with head_dim rounded up to the next valid power of 2.
- The other option I considered was building up the final dot products with smaller blocks (this would probably work, but for the sake of code complexity I'm going with this option for now).

### Corollary
We had a bug in our backwards kernel where we were using index_k instead of index_v. This should have shown up for the qk_head_dim != v_head_dim cases.
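A usage sketch of the newly supported shape (head_dim = 96 is not a power of 2; a CUDA device is assumed):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Per this PR, head_dim 96 is rounded up internally to the next valid
# power of 2 (128) when launching blocks.
q, k, v = (
    torch.randn(2, 4, 128, 96, device="cuda", dtype=torch.float16)
    for _ in range(3)
)
out = torch.compile(flex_attention)(q, k, v)
```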

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133495
Approved by: https://github.com/Chillee
2025-01-23 17:05:38 +00:00
Aaron Orenstein
0afd335174 PEP585 update - torch/nn torch/optim torch/package torch/profiler torch/serialization torch/sparse torch/xpu (#145175)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145175
Approved by: https://github.com/bobrenjc93
2025-01-21 16:57:27 +00:00
PyTorch MergeBot
5fd881a5b6 Revert "PEP585 update - torch/nn torch/optim torch/package torch/profiler torch/serialization torch/sparse torch/xpu (#145175)"
This reverts commit 54a00af2c6.

Reverted https://github.com/pytorch/pytorch/pull/145175 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to break some trunk tests ([comment](https://github.com/pytorch/pytorch/pull/145175#issuecomment-2603418267))
2025-01-21 00:49:55 +00:00
Aaron Orenstein
54a00af2c6 PEP585 update - torch/nn torch/optim torch/package torch/profiler torch/serialization torch/sparse torch/xpu (#145175)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145175
Approved by: https://github.com/bobrenjc93
2025-01-20 22:32:59 +00:00
Mario Vasilev
49bdc418be Add strict kwarg to nn.Module.set_submodule and fix bug for non dot delineated strings (#143455)
Before this fix, `set_submodule` used to create leaf modules when the target was not a dot-delimited string. After the fix, it will not create a new attribute if the target is a non-dot-delimited string. If you want to create leaf nodes under `nn.Module` parent nodes, you can use `replace_or_create_new_leaf_module`.
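A sketch of the new kwarg (my reading of the PR: with `strict=True` the target must already exist):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

# Replaces the existing child "0"; with strict=True this raises instead
# of creating a new attribute when the target is missing.
model.set_submodule("0", nn.Linear(4, 4), strict=True)
```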

Fixes https://github.com/pytorch/pytorch/issues/143441

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143455
Approved by: https://github.com/mikaylagawarecki
2025-01-16 05:06:33 +00:00
Boyuan Feng
069419569d [PagedAttention] Support different input position for each batch index (#144693)
In LLM inference, each request usually has a different prefill length, leading to a different input position for each batch index. This PR adds such support to paged attention.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144693
Approved by: https://github.com/drisspg
2025-01-15 18:03:52 +00:00
Aaron Orenstein
d782e46a36 [BE] typing for decorators - library (#138969)
Test Plan: unit tests

Differential Revision: D62302678

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138969
Approved by: https://github.com/zou3519
2025-01-15 17:08:55 +00:00
cyy
d87aad6877 [5/N] Apply Ruff fixes and pyupgrade to Python 3.9 (#144205)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144205
Approved by: https://github.com/albanD
2025-01-15 04:00:47 +00:00
Aaron Gokaslan
91dbd7b75c [BE]: Improve typing inference with TypeIs (#144682)
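For context, a minimal sketch of the `TypeIs` pattern this PR adopts (names are illustrative):

```python
from __future__ import annotations

import torch
from typing_extensions import TypeIs

def is_tensor(x: object) -> TypeIs[torch.Tensor]:
    # Unlike TypeGuard, TypeIs also narrows the negative branch.
    return isinstance(x, torch.Tensor)

def describe(x: torch.Tensor | int) -> str:
    if is_tensor(x):
        return str(x.shape)  # narrowed to Tensor
    return str(x + 1)        # narrowed to int
```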
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144682
Approved by: https://github.com/albanD

Co-authored-by: Aaron Orenstein <aorenste@meta.com>
2025-01-13 21:14:31 +00:00
bobrenjc93
f93d786f73 remove allow-untyped-defs from torch/nn/parameter.pyi (#144654)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144654
Approved by: https://github.com/Skylion007
2025-01-13 19:02:31 +00:00
Alexander Kurakin
68dad26b95 torch/nn/modules/linear.py: docs: improvements (#138484)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138484
Approved by: https://github.com/mikaylagawarecki
2025-01-10 20:03:43 +00:00
Aaron Gokaslan
307ca094c9 [BE]: Remove redundant contiguous copy in flex attention (#144467)
Removes a potentially redundant copy; instead, the memory_format kwarg is used to fuse both operations into a single copy.
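The general shape of the fusion (a sketch; the actual call site lives in the flex attention code):

```python
import torch

x = torch.randn(2, 8, 128, 64).transpose(1, 2)  # non-contiguous view

y_two_copies = x.clone().contiguous()  # clone, then a second copy
y_one_copy = x.clone(memory_format=torch.contiguous_format)  # one fused copy
```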
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144467
Approved by: https://github.com/awgu
2025-01-09 18:30:09 +00:00
Mikayla Gawarecki
b8f383107e Link to transformer tutorial in transformer docs (#144425)
<img width="1045" alt="Screenshot 2025-01-08 at 4 50 20 PM" src="https://github.com/user-attachments/assets/05adfecb-8a23-4c48-9a2c-50c5b3f886b0" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144425
Approved by: https://github.com/albanD
2025-01-09 17:42:09 +00:00
bobrenjc93
168c2cb3f3 remove allow-untyped-defs from torch/nn/utils/_deprecation_utils.py (#144231)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144231
Approved by: https://github.com/albanD
2025-01-07 02:22:22 +00:00
Guilherme Leobas
e222dd5d25 Rewrite _reparametrize_module to use contextmanager (#138203)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138203
Approved by: https://github.com/zou3519
ghstack dependencies: #136033, #140604
2025-01-06 16:56:22 +00:00
bobrenjc93
52742b07c5 remove allow-untyped-defs from nn/utils/_deprecation_utils.py (#144136)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144136
Approved by: https://github.com/aorenste
2025-01-03 23:44:14 +00:00
PyTorch MergeBot
2409b49a33 Revert "Rewrite _reparametrize_module to use contextmanager (#138203)"
This reverts commit 7bf3b7cdc5.

Reverted https://github.com/pytorch/pytorch/pull/138203 on behalf of https://github.com/guilhermeleobas due to breaking one of the benchmarks (moco) ([comment](https://github.com/pytorch/pytorch/pull/138203#issuecomment-2569634001))
2025-01-03 18:17:32 +00:00
Joel Schlosser
228b228449 Fix batch-specific attention mod for NJT + Flex (#143866)
Fixes #143788
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143866
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
2024-12-27 20:51:41 +00:00
Guilherme Leobas
7bf3b7cdc5 Rewrite _reparametrize_module to use contextmanager (#138203)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138203
Approved by: https://github.com/zou3519
ghstack dependencies: #136033, #140604
2024-12-20 12:02:27 +00:00
bobrenjc93
03991798ca remove allow-untyped-defs for torch/nn/parallel/__init__.py (#143437)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143437
Approved by: https://github.com/oulgen
2024-12-18 08:50:37 +00:00
drisspg
744a303dee [FlexAttention] Optimizing learned bias perf to dq calc (#142281)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142281
Approved by: https://github.com/Chillee
2024-12-15 21:44:32 +00:00
Tristan Rice
688f44824b DistributedDataParallel: add init_sync option to control collectives during initialization (#142824)
This controls whether or not we run collectives during the DDP init function. This makes it easier to use fault-tolerant ProcessGroup implementations that may not start at the same time.

torchft uses a dummy process group and a comm hook to get around these checks. With this change torchft can use the normal ProcessGroup API via the stock comm hook.

https://github.com/pytorch-labs/torchft/blob/main/torchft/ddp.py#L50-L59
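A usage sketch (single process with gloo for brevity; normally launched via torchrun):

```python
import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# init_sync=False skips the collectives DDP runs during __init__, so
# ranks need not reach the constructor at the same time.
model = DDP(nn.Linear(8, 8), init_sync=False)
```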

Test plan:

```
pytest test/distributed/test_c10d_pypg.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142824
Approved by: https://github.com/wconstab, https://github.com/fegin, https://github.com/H-Huang
2024-12-11 20:28:38 +00:00
Jane Xu
fd65bd755d [BE] replace incorrect .. note:: invocations (#142868)
Something I've noticed is that a lot of the distributed sites don't render on our docs at all, but if they ever do, the notes will render properly now 😛
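For reference, the reST form the fix restores (a generic docstring sketch):

```python
def all_reduce_sketch():
    """Reduce a tensor across ranks.

    .. note::
        The directive needs the leading two dots, the double colon, and
        an indented body; malformed variants render as plain text.
    """
```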

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142868
Approved by: https://github.com/albanD
2024-12-11 19:58:18 +00:00
jianan-gu
d51e6fa7f6 [inductor][cpp] Add FlexAttention support for CPU inference (#141453)
This PR brings FlexAttention inference support to the inductor backend in torch.compile (supported precisions: bf16 and fp32) on CPUs.

Based on the existing CPP template, this PR extends and implements a FlexAttention CPP template to support broad attention variants, while bringing optimized performance on CPUs.

With this, users can transparently extend their FlexAttention usage to CPUs with good and common support from torch.compile, in both functionality and performance.
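A CPU inference sketch of the path this enables (the `score_mod` is my own example):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def rel_bias(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx)  # simple relative-position bias

# Plain CPU tensors in bf16, one of the two supported precisions.
q, k, v = (torch.randn(1, 4, 256, 64, dtype=torch.bfloat16) for _ in range(3))
out = torch.compile(flex_attention)(q, k, v, score_mod=rel_bias)
```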

For unit tests, this PR includes a subset of the critical tests for CPUs (run as inference tests):
```
pytest test/inductor/test_flex_attention.py
`TestFlexAttention`
#common functions:
run_test
preprocess_paged_attention
run_paged_attention
run_test_with_paged_attention
run_test_with_call
run_dynamic_test
run_automatic_dynamic_test

#test functions:
test_builtin_score_mods
test_builtin_score_mods_automatic_dynamic
test_builtin_score_mods_different_seqlen
test_builtin_score_mods_different_block_size
test_kv_batch_broadcast
test_GQA
test_cpu_error_message_return_lse
test_validate_cpu_dtype_error_message

`TestPagedAttention`
#test function:
test_paged_builtin_score_mods
```
For the remaining UTs in `test/inductor/test_flex_attention.py` and `test/inductor/test_flex_decoding.py`, the larger diff (1500+ LOC) would make this PR hard to review, so a follow-up PR will enable and refactor the CPU device UTs.

Beyond that, more optimizations are planned in follow-up PRs, including:

- Block sparse computation
- Flash decoding tuning

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141453
Approved by: https://github.com/drisspg, https://github.com/leslie-fang-intel

Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>
2024-12-10 11:11:09 +00:00
Nikita Shulga
3291b0a013 [DataParallel] Skip for MPS device (#142448)
As `torch._C._scatter` is only defined for CUDA/ROCm (and maybe XPU?)

This is a regression introduced by https://github.com/pytorch/pytorch/pull/141098 that went unnoticed due to https://github.com/pytorch/pytorch/issues/142206

Test plan:
```
python test_autograd.py -v -k test_dataparallel_saved_tensors_hooks
```

Before this change it failed with
```
ERROR: test_dataparallel_saved_tensors_hooks (__main__.TestMultithreadAutograd.test_dataparallel_saved_tensors_hooks)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/malfet/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3108, in wrapper
    method(*args, **kwargs)
    ~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/malfet/git/pytorch/pytorch/test/test_autograd.py", line 13074, in test_dataparallel_saved_tensors_hooks
    model = torch.nn.DataParallel(Model())
  File "/Users/malfet/git/pytorch/pytorch/torch/nn/parallel/data_parallel.py", line 153, in __init__
    raise RuntimeError("no available devices were found")
RuntimeError: no available devices were found
```

After this change it passes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142448
Approved by: https://github.com/kit1980
2024-12-10 02:49:23 +00:00
PyTorch MergeBot
beeffe77e4 Revert "[inductor][cpp] Add FlexAttention support for CPU inference (#141453)"
This reverts commit db379ed1ad.

Reverted https://github.com/pytorch/pytorch/pull/141453 on behalf of https://github.com/malfet due to This breaks tests on platforms compiled without MKLDNN, namely MacOS, see https://github.com/pytorch/pytorch/actions/runs/12245441371/job/34159967794 ([comment](https://github.com/pytorch/pytorch/pull/141453#issuecomment-2529710573))
2024-12-09 22:57:59 +00:00
jianan-gu
db379ed1ad [inductor][cpp] Add FlexAttention support for CPU inference (#141453)
This PR brings FlexAttention inference support to the inductor backend in torch.compile (supported precisions: bf16 and fp32) on CPUs.

Based on the existing CPP template, this PR extends and implements a FlexAttention CPP template to support broad attention variants, while bringing optimized performance on CPUs.

With this, users can transparently extend their FlexAttention usage to CPUs with good and common support from torch.compile, in both functionality and performance.

For unit tests, this PR includes a subset of the critical tests for CPUs (run as inference tests):
```
pytest test/inductor/test_flex_attention.py
`TestFlexAttention`
#common functions:
run_test
preprocess_paged_attention
run_paged_attention
run_test_with_paged_attention
run_test_with_call
run_dynamic_test
run_automatic_dynamic_test

#test functions:
test_builtin_score_mods
test_builtin_score_mods_automatic_dynamic
test_builtin_score_mods_different_seqlen
test_builtin_score_mods_different_block_size
test_kv_batch_broadcast
test_GQA
test_cpu_error_message_return_lse
test_validate_cpu_dtype_error_message

`TestPagedAttention`
#test function:
test_paged_builtin_score_mods
```
For the remaining UTs in `test/inductor/test_flex_attention.py` and `test/inductor/test_flex_decoding.py`, the larger diff (1500+ LOC) would make this PR hard to review, so a follow-up PR will enable and refactor the CPU device UTs.

Beyond that, more optimizations are planned in follow-up PRs, including:

- Block sparse computation
- Flash decoding tuning

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141453
Approved by: https://github.com/drisspg, https://github.com/leslie-fang-intel

Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>
2024-12-09 18:44:39 +00:00
Fabian Keller
5e8e1d725a Remove some unused type ignores (round 1) (#142325)
Over time, a large number of the existing type ignores have become irrelevant/unused/dead as a result of improvements in annotations and type checking.

Having these `# type: ignore` comments linger around is not ideal for two reasons:

- They are syntactically verbose/ugly.
- They could hide genuine bugs in the future: if a refactoring actually introduced a bug, the ignore could mask it.

I'm counting over 1500 unused ignores already. This is a first PR that removes some of them. Note that I haven't touched type ignores that looked "conditional", like the import challenge mentioned in https://github.com/pytorch/pytorch/pull/60006#issuecomment-2480604728. I will address those at a later point, and would eventually enable `warn_unused_ignores = True` in the mypy configuration, as discussed in that comment, to prevent accumulating more dead ignores going forward.
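A minimal illustration of what `warn_unused_ignores` catches:

```python
x: int = int("3")  # type: ignore[arg-type]
# The call is actually well-typed, so with warn_unused_ignores = True
# mypy reports: error: Unused "type: ignore" comment
```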

This PR should have no effect on runtime at all.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142325
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
2024-12-09 18:23:46 +00:00
PyTorch MergeBot
7101dcfb98 Revert "[inductor][cpp] Add FlexAttention support for CPU inference (#141453)"
This reverts commit 7edbde3334.

Reverted https://github.com/pytorch/pytorch/pull/141453 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it is failing periodic NO_AVX2 ([comment](https://github.com/pytorch/pytorch/pull/141453#issuecomment-2527377475))
2024-12-09 09:26:20 +00:00
Xuehai Pan
e1196dfe51 Deprecate torch._utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------
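The public replacement, per the deprecation message (one-line sketch):

```python
import torch

# Old, deprecated: torch._utils.is_compiling()
inside_compile = torch.compiler.is_compiling()
```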

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-12-08 22:55:36 +00:00
jianan-gu
7edbde3334 [inductor][cpp] Add FlexAttention support for CPU inference (#141453)
This PR brings FlexAttention inference support to the inductor backend in torch.compile (supported precisions: bf16 and fp32) on CPUs.

Based on the existing CPP template, this PR extends and implements a FlexAttention CPP template to support broad attention variants, while bringing optimized performance on CPUs.

With this, users can transparently extend their FlexAttention usage to CPUs with good and common support from torch.compile, in both functionality and performance.

For unit tests, this PR includes a subset of the critical tests for CPUs (run as inference tests):
```
pytest test/inductor/test_flex_attention.py
`TestFlexAttention`
#common functions:
run_test
preprocess_paged_attention
run_paged_attention
run_test_with_paged_attention
run_test_with_call
run_dynamic_test
run_automatic_dynamic_test

#test functions:
test_builtin_score_mods
test_builtin_score_mods_automatic_dynamic
test_builtin_score_mods_different_seqlen
test_builtin_score_mods_different_block_size
test_kv_batch_broadcast
test_GQA
test_cpu_error_message_return_lse
test_validate_cpu_dtype_error_message

`TestPagedAttention`
#test function:
test_paged_builtin_score_mods
```
For the remaining UTs in `test/inductor/test_flex_attention.py` and `test/inductor/test_flex_decoding.py`, the larger diff (1500+ LOC) would make this PR hard to review, so a follow-up PR will enable and refactor the CPU device UTs.

Beyond that, more optimizations are planned in follow-up PRs, including:

- Block sparse computation
- Flash decoding tuning

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141453
Approved by: https://github.com/jgong5, https://github.com/drisspg, https://github.com/leslie-fang-intel

Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>
2024-12-08 07:57:21 +00:00
eqy
8fc6d3a5d8 [SDPA] Allow user-specified priority order with context manager (#140467)
TODO: docs changes?
For better debuggability of issues like https://github.com/pytorch/pytorch/issues/139298

Better testing, current sketch:

``` Python
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(64, 1024, 8, 64, dtype=torch.half, device='cuda')
print(torch._C._get_sdp_priority_order())

orders = [[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION],
          [SDPBackend.MATH, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION],
          [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]]
import time
times = list()
for order in orders:
    print(order)
    with sdpa_kernel(order, set_priority=True):
        scaled_dot_product_attention(q, q, q)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with sdpa_kernel(order, set_priority=True):
        scaled_dot_product_attention(q, q, q)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    times.append(t1 - t0)
print(times)
assert times[0] < times[1]
assert times[0] > times[2]
assert times[1] > times[2]
print(torch._C._get_sdp_priority_order())
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140467
Approved by: https://github.com/drisspg
2024-12-06 07:56:35 +00:00
Angela Yi
a9d84875a9 Fix mha torch._check in jit tracing (#142059)
Test Plan: `buck2 run @//mode/dev-nosan //mobile-vision/d2go/projects_oss/detr:tests -- -r test_detr_fbnet_export`

Differential Revision: D66769339

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142059
Approved by: https://github.com/ezyang
2024-12-05 18:38:17 +00:00
Andreas Säuberli
d24b147520 Update dead reference link for triplet margin loss (#142071)
The current link for _Learning local feature descriptors with triplets and shallow convolutional neural networks_ (https://www.bmva.org/bmvc/2016/papers/paper119/index.html) is dead (404). The paper is archived here: https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142071
Approved by: https://github.com/albanD
2024-12-05 15:01:10 +00:00
UV
0318589e87 Changed 'standard-deviation' to 'variance' in GroupNorm documentation (#141982)
Fixes #141315

Updated the GroupNorm documentation to replace 'standard-deviation' with 'variance' to accurately reflect the calculation method.
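For reference, the documented GroupNorm computation is written in terms of the variance, which motivates the wording change:

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$$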

@pytorchbot label "topic: not user facing"
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141982
Approved by: https://github.com/mikaylagawarecki
2024-12-04 22:49:45 +00:00
angelayi
80705d3abf Convert assert to torch._check in MHA (#141918)
Fixes https://github.com/pytorch/pytorch/issues/139610
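A sketch of the conversion pattern (the helper and message are hypothetical):

```python
import torch

def check_embed_dim(query: torch.Tensor, embed_dim: int) -> None:
    # torch._check behaves like an assert, but stays visible to the
    # tracing/export stack instead of being stripped or hard-failing.
    torch._check(
        query.size(-1) == embed_dim,
        lambda: f"expected last dim {embed_dim}, got {query.size(-1)}",
    )
```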
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141918
Approved by: https://github.com/ezyang
2024-12-03 21:58:02 +00:00