Commit Graph

26 Commits

Author SHA1 Message Date
Angel Li
3a4140bf8e [FlexAttention] fixing learnable bias assertion error in inductor (#161170)
Users encountered unexpected behaviour when using FlexAttention with learnable biases, including assertion errors (#157677).

We traced the root cause to how subgraph buffers were registered: this caused naming inconsistencies and, ultimately, incorrect retrieval later on. The problem only arose when the model was compiled as a whole (i.e., using @torch.compile), since only then could the names conflict.

In this PR, we register the buffers with the base graph to solve this issue.
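
A minimal sketch of the kind of setup that exposes the issue (tensor shapes, sizes, and the score_mod body are assumptions, not the exact repro from #157677): a learnable bias captured by score_mod, with the whole model compiled rather than only the flex_attention call.

```Python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical learnable bias captured by score_mod (size chosen arbitrarily).
bias = torch.nn.Parameter(torch.randn(128, 128, device="cuda"))

def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[q_idx, kv_idx]

# Compiling the whole model (not just flex_attention) is what surfaced the
# subgraph-buffer naming conflict described above.
@torch.compile
def model(q, k, v):
    return flex_attention(q, k, v, score_mod=score_mod)

q = k = v = torch.randn(1, 4, 128, 64, device="cuda", requires_grad=True)
model(q, k, v).sum().backward()  # grads also flow into the captured bias
```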

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161170
Approved by: https://github.com/drisspg
2025-08-23 06:24:22 +00:00
Yu, Guangye
c68af9af1b Fix XPU CI UT test_circular_dependencies (#158189)
# Motivation
Fixes https://github.com/pytorch/pytorch/issues/110040.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158189
Approved by: https://github.com/Skylion007, https://github.com/cyyever
2025-07-13 09:30:57 +00:00
Jason Ansel
06604c4ec1 [inductor] Refactor op handlers part 5 (#146257)
This makes OpHandler a normal class that uses inheritance, and removes the typing workarounds that were needed because it previously wasn't one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146257
Approved by: https://github.com/shunting314
ghstack dependencies: #146252, #146254, #146255
2025-02-08 18:00:30 +00:00
Jason Ansel
71498aeae3 [inductor] Refactor op handlers part 2 (#146252)
This replaces the `__getattr__()` pattern used in (some) OpHandlers with a `DefaultHandler` class that has an implementation of every op that calls `self._default()`.
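
For illustration, a simplified sketch of the pattern change described above (class and op names are invented here, not the actual torch/_inductor handlers):

```Python
# Before: ops were resolved dynamically via __getattr__, which type checkers
# cannot see through.
class GetattrOpHandler:
    def __getattr__(self, name):
        def inner(*args, **kwargs):
            return self._default(name, args, kwargs)
        return inner

    def _default(self, name, args, kwargs):
        return f"{name}{args}"

# After: a base class with an explicit method per op, each forwarding to
# _default(), so every op has a real, typable signature.
class DefaultHandler:
    def _default(self, name, args, kwargs):
        raise NotImplementedError

    def add(self, a, b):
        return self._default("add", (a, b), {})

    def mul(self, a, b):
        return self._default("mul", (a, b), {})
```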

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146252
Approved by: https://github.com/yanboliang
2025-02-08 18:00:00 +00:00
PyTorch MergeBot
e0cf519ade Revert "[inductor] Refactor op handlers part 2 (#146252)"
This reverts commit 13f0436abd.

Reverted https://github.com/pytorch/pytorch/pull/146252 on behalf of https://github.com/atalman due to Sorry need to revert, failing internally ([comment](https://github.com/pytorch/pytorch/pull/146252#issuecomment-2638305417))
2025-02-06 00:04:04 +00:00
PyTorch MergeBot
49effa0deb Revert "[inductor] Refactor op handlers part 5 (#146257)"
This reverts commit d3dd3eeb7f.

Reverted https://github.com/pytorch/pytorch/pull/146257 on behalf of https://github.com/atalman due to Sorry need to revert https://github.com/pytorch/pytorch/pull/146252 ([comment](https://github.com/pytorch/pytorch/pull/146257#issuecomment-2638251994))
2025-02-05 23:20:38 +00:00
Jason Ansel
d3dd3eeb7f [inductor] Refactor op handlers part 5 (#146257)
This makes OpHandler a normal class that uses inheritance, and removes the typing workarounds that were needed because it previously wasn't one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146257
Approved by: https://github.com/shunting314
ghstack dependencies: #146225, #146226, #146235, #146252, #146254, #146255
2025-02-04 23:36:25 +00:00
Jason Ansel
13f0436abd [inductor] Refactor op handlers part 2 (#146252)
This replaces the `__getattr__()` pattern used in (some) OpHandlers with a `DefaultHandler` class that has an implementation of every op that calls `self._default()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146252
Approved by: https://github.com/yanboliang
ghstack dependencies: #146225, #146226, #146235
2025-02-04 23:36:01 +00:00
Aaron Orenstein
bac62341eb PEP585 update - torch/_inductor (#145198)
See #145101 for details.
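
For context, a hedged example of the kind of annotation change a PEP 585 update makes (illustrative function only, not code from the PR):

```Python
# typing-module generics (List, Dict, Tuple) replaced by the builtin generics
# available since Python 3.9.
# Before:
#   from typing import Dict, List, Tuple
#   def group(xs: List[int]) -> Dict[str, Tuple[int, ...]]: ...
# After:
def group(xs: list[int]) -> dict[str, tuple[int, ...]]:
    return {"all": tuple(xs)}
```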

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145198
Approved by: https://github.com/bobrenjc93
2025-01-21 21:04:33 +00:00
bobrenjc93
a3ab27b8e0 Migrate from Tuple -> tuple in torch/_inductor (#144264)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144264
Approved by: https://github.com/eellison
2025-01-07 03:27:27 +00:00
Tom Ritchford
da67a6a7bb [inductor] Replace set by OrderedSet (#138466)
Uses the set_linter from https://github.com/pytorch/pytorch/pull/138454, plus considerable manual editing.
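
As a rough illustration of why the swap matters (the dict-based stand-in below is not the actual OrderedSet implementation, and the buffer names are made up):

```Python
# Plain set() has no guaranteed iteration order across interpreter runs (string
# hashing is randomized), which can make generated code non-deterministic; an
# insertion-ordered set (emulated here with a dict) always iterates in
# insertion order.
buffer_names = ["buf3", "buf1", "buf2"]

unordered = set()
ordered: dict[str, None] = {}
for name in buffer_names:
    unordered.add(name)
    ordered[name] = None

print(list(unordered))  # order can change from run to run
print(list(ordered))    # always ['buf3', 'buf1', 'buf2']
```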

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138466
Approved by: https://github.com/eellison
2024-12-13 16:08:45 +00:00
drisspg
91f7c547ec [FlexAttention] add support for learnable biases in Inductor (#137452)
# Summary

The follow-up PR to https://github.com/pytorch/pytorch/pull/137526. In this PR, we update the lowerings for the flex_attention backward kernel to generate fused backward gradient calculations for any captured buffers that require grads.

We do this using tl.atomic_add to scatter the correct gradients into a zeroed-out buffer for each captured buffer that requires grads. We added many test cases and, along the way, found and fixed some masking bugs.

There are likely some performance cliffs here, particularly across dtypes and on different GPUs. We plan to profile the current strategy in a follow-up. For now, we are explicitly choosing reduced memory over increased performance.

By using atomics, we do not need to realize a full attention-scores matrix. However, this comes with two downsides: it is potentially slower in some cases, and the gradient calculation for captured buffers is non-deterministic.

## Worked Example

Let's take the case where you read from one bias that doesn't require grad and use it to index into another that does.

ScoreMod:
```Python
import torch

# seq_length, device, and dtype stand in for the test harness' params/self attributes.
seq_length = 512
device = "cuda"
dtype = torch.float16

bias = torch.randn(
    seq_length,
    device=device,
    dtype=dtype,
    requires_grad=True,
)

offset = torch.randint(
    0,
    seq_length,
    (seq_length,),
    device=device,
)

def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[offset[q_idx]]
```

Below, I have removed everything but the new subgraph injected into the backward kernel:

```Python
    dsT = pT * (dpT - Di[None, :])
    # ~~~~~~~~~~~~~~~~~~~ Apply joint modification  ~~~~~~~~~~~~~~~~~~~
    grad_scores = (dsT)

    # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
    idx_b = off_z
    idx_h = off_hq
    idx_m = m
    idx_n = n
    scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN
    tmp4 = (dsT).to(tl.float32)
    tl.atomic_add(out_ptr1 + (tl.broadcast_to(tl.load(in_ptr16 + idx_m), tmp4.shape)), tmp4, scatter_mask, sem='relaxed')

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
```
## Key points
* We always accumulate to float32 grad buffers regardless of the dtype used in the forward. This is because we normally do all intra-kernel computation with fp32 accumulation, and we want the same behavior for atomic additions.
* We are currently restricted to one scatter per kernel. I have some ideas on fx rewrites that would remove this restriction, but for now there is a clear error message with a workaround; lifting the restriction is left as a follow-up.
* More extensive performance/memory profiling will be done in a follow-up.

### Toy E2E example
I have a toy E2E training example PR in the attention-gym for now: https://github.com/pytorch-labs/attention-gym/pull/84/
I plan to update it to a realistic learnable bias before landing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137452
Approved by: https://github.com/Chillee
2024-11-25 19:08:34 +00:00
chilli
c1f21bf2b6 Made FlexAttention error on subgraph lowering failure (#140331)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140331
Approved by: https://github.com/drisspg
2024-11-17 02:43:58 +00:00
PyTorch MergeBot
de34f581f1 Revert "Made FlexAttention error on subgraph lowering failure (#140331)"
This reverts commit e68bc76c28.

Reverted https://github.com/pytorch/pytorch/pull/140331 on behalf of https://github.com/malfet due to Looks like it regressed trunk, see 55f1959fc1/1 ([comment](https://github.com/pytorch/pytorch/pull/140331#issuecomment-2479435705))
2024-11-15 17:00:21 +00:00
chilli
e68bc76c28 Made FlexAttention error on subgraph lowering failure (#140331)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140331
Approved by: https://github.com/drisspg
2024-11-15 04:26:01 +00:00
Aaron Orenstein
d95aedf5fd [BE] typing for decorators - fx/_compatibility (part 1) (#134202)
Part of #134054.

This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change,
so these 'type: ignore' comments are being landed for pytorch in advance of them actually being needed.
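
For reference, a simplified sketch of what typing a signature-preserving decorator looks like (illustrative only, not the actual fx/_compatibility code):

```Python
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

# A decorator annotated with ParamSpec/TypeVar preserves the wrapped function's
# signature for mypy; the attribute write still needs a 'type: ignore'.
def compatibility(is_backward_compatible: bool) -> Callable[[Callable[P, R]], Callable[P, R]]:
    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
        fn._compatibility = is_backward_compatible  # type: ignore[attr-defined]
        return fn
    return decorator

@compatibility(is_backward_compatible=True)
def double(x: int) -> int:
    return x * 2
```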

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
2024-08-22 17:07:33 +00:00
Oguz Ulgen
09f9c256ad Add basic mypy annotations to inductor (#132416)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132416
Approved by: https://github.com/XuehaiPan, https://github.com/jamesjwu
ghstack dependencies: #132415
2024-08-04 18:43:37 +00:00
PyTorch MergeBot
f2ddd5e9e0 Revert "Add basic mypy annotations to inductor (#132416)"
This reverts commit 78927d37f6.

Reverted https://github.com/pytorch/pytorch/pull/132416 on behalf of https://github.com/ZainRizvi due to Sorry, this PR has entered a weird state in the diff train. Trying to revert it to skip it, and then we can try relanding it ([comment](https://github.com/pytorch/pytorch/pull/132415#issuecomment-2267631785))
2024-08-04 18:39:29 +00:00
Oguz Ulgen
78927d37f6 Add basic mypy annotations to inductor (#132416)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132416
Approved by: https://github.com/XuehaiPan, https://github.com/jamesjwu
ghstack dependencies: #132415
2024-08-01 20:14:25 +00:00
PyTorch MergeBot
945bf78894 Revert "[BE] typing for decorators - fx/_compatibility (#131568)"
This reverts commit 193f62fde9.

Reverted https://github.com/pytorch/pytorch/pull/131568 on behalf of https://github.com/clee2000 due to same as https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359 but I clicked the wrong link by accident.  This is where it actually starts ([comment](https://github.com/pytorch/pytorch/pull/131568#issuecomment-2254330781))
2024-07-28 03:43:39 +00:00
Aaron Orenstein
193f62fde9 [BE] typing for decorators - fx/_compatibility (#131568)
See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131568
Approved by: https://github.com/justinchuby, https://github.com/oulgen, https://github.com/zou3519
2024-07-25 22:24:19 +00:00
Yanbo Liang
5f3f14e5e4 [BE] Annotate subgraph_lowering (#131545)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131545
Approved by: https://github.com/anijain2305, https://github.com/zou3519
2024-07-25 04:35:26 +00:00
Xuehai Pan
b6d477fd56 [BE][Easy][16/19] enforce style for empty lines in import segments in torch/_i*/ (#129768)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter; the enforced import layout is sketched after the review command below.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```
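
As an illustration, the kind of import-section layout being enforced (module names below are only examples, not taken from the PR):

```Python
# One blank line between segments: __future__, standard library, third-party,
# and first-party imports.
from __future__ import annotations

import logging
import os

import sympy

import torch
from torch._inductor import config
```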

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129768
Approved by: https://github.com/jansel
2024-07-20 16:20:58 +00:00
Aaron Orenstein
afe15d2d2f Flip default value for mypy disallow_untyped_defs [3/11] (#127840)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127840
Approved by: https://github.com/oulgen
2024-06-08 18:28:01 +00:00
Peter Bell
24b64fc482 [HOP][inductor] Support pytrees as associative_scan input (#122137)
This allows `associative_scan` to take an arbitrary pytree of tensors,
which is flattened to its leaves before calling the `associative_scan`
higher-order operator.

I also add support in inductor to generate code for scanning over sequences
of tensors.
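
A minimal sketch of the new pytree usage (this assumes the prototype frontend in torch._higher_order_ops.associative_scan, an internal API that may change, and a GPU build with Inductor):

```Python
import torch
from torch._higher_order_ops.associative_scan import associative_scan

def add_pairs(a, b):
    # Inputs mirror the pytree structure of xs: here, 2-tuples of tensors.
    return (a[0] + b[0], a[1] + b[1])

xs = (torch.randn(8, 4, device="cuda"), torch.randn(8, 4, device="cuda"))

@torch.compile
def running_sums(xs):
    # Scan along dim 0 with a pointwise combine; returns a pytree matching xs.
    return associative_scan(add_pairs, xs, 0)

out = running_sums(xs)
```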

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122137
Approved by: https://github.com/lezcano, https://github.com/Chillee
ghstack dependencies: #119430
2024-05-06 11:29:28 +00:00
Peter Bell
7ecbbc40c3 [HOP][inductor] Add higher order associative scan operator (#119430)
Currently this only supports single-tensor scans, e.g. `cumsum`, `cumprod`, `logcumsumexp`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119430
Approved by: https://github.com/Chillee
2024-04-23 14:40:13 +00:00