Commit Graph

35 Commits

Author SHA1 Message Date
Tugsbayasgalan Manlaibaatar
c933af2709 Switch to predispatch (#123573)
This PR switches export IR from aot-dispatch to pre-dispatch IR.

**What is pre-dispatch IR and why should you care?**

Currently the default IR returned by torch.export can contain only functional ATen operators after ALL pytorch dispatcher decompositions (for example, CompositeImplicitAutograd) run.

In contrast, pre-dispatch IR refers to an IR that can contain all functional ATen operators (i.e., not just from the core subset), before any decomposition happens, as well as operators that manipulate autograd state. Pre-dispatch IR closely resembles eager PyTorch computation, but is still functional and serializable by torch.export. As a result:
- You can train the pre-dispatch IR in eager mode as the IR contains necessary information for the autograd engine to automatically generate a backward graph.
- You can write sound graph transformations more easily as the IR is functional.
- Since it is an ATen IR, it is still normalized. For example, torch.add has multiple overloads, but aten.add.Tensor is unique in this IR.

If you want to get the core aten IR out of `torch.export`, you will need to:
```
ep = torch.export.export(M(), inputs)
ep_for_core_aten = ep.run_decompositions()
```

Differential Revision: [D56273267](https://our.internmc.facebook.com/intern/diff/D56273267)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123573
Approved by: https://github.com/gmagogsfm
2024-04-24 00:51:09 +00:00
drisspg
f4e2a226aa ScoreMod API (#121845)
# Summary

This PR adds a new higher-order_op: `templated_attention`.  This op is designed to extend the functionality of torch.nn.fucntional.scaled_dot_product_attention.  PyTorch has efficient pre-written fused-attention kernels. However, users want to modify how scores are computed (a substep inside attention) -- this traditionally requires the user to write their own attention kernel. One such modification to attention scores that is not currently supported by the top level SDPA op is:[ Attention with Linear Biases (ALiBi](https://arxiv.org/abs/2108.12409)).

This higher-order op will instead accept a callable( 'score_mod') function that is through torch.compile will be used to create an efficient attention kernel instantiation.

### Details

This HOP utilizes the existing fx and HOP infra to capture and convert the User `score-mod` function and convert to an FX graph module. Inductor then consumes this HOP that has a `ir.Subgraph` input. It will inline this lowered subgraph into a triton kernel which performs fused attention with the modification to the scores matrix inlined.

### API

The API for a score_mod function should be as follows:

```Python
def score_mod(score: torch.Tensor, batch: torch.Tensor, head: torch.Tensor, token_1: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```

This function receives five parameters:

- `score`: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors.
- `batch`, `head`, `seq_len_q`, `seq_len_kv`: Scalar tensors indicating the batch index, head index, query index, and key/value index, respectively, with torch.int data type and located on the same device as the score tensor.

Consider inputs query, key, value of shapes (2, 4, 16, 8), leading to an intermediate attention score matrix of shape (2, 4, 16, 16)

The score_mod function will be vectorized over each element of this matrix. For instance, modifying the score at the position corresponding to the 0th batch, 2nd head, between the 8th query and the 9th key element, would be invoked as:

```Python
score_mod(score[0,2,8,9], torch.tensor(0), torch.tensor(2), torch.tensor(8), torch.tensor(9))
```

### Examples
```Python
import torch
from torch.nn.attention.templated_attention import templated_attention

torch.manual_seed(0)

# Lets create some input tensors
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim)
query = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
key = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
value = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)

# Lets create a fun new score_modification! I will call this
# Checkerboard. It will reduce the score for neighboring tokens (1 step apart)
# in the sequence. And increase the score for tokens 2 steps apart. For everything
# else, the score will remain the same.

def checkerboard(score, batch, head, token_q, token_kv):
    score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score)
    score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score)
    return score

# Lets call templated_attention with this new score modification
output = templated_attention(query, key, value, score_mod=checkerboard)

compiled_templated_attention = torch.compile(templated_attention)
out_compiled = compiled_templated_attention(query, key, value, score_mod=checkerboard)

torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2)
```

### Future Work
- This PR is currently only forward only. However the triton kernel for backwards where score_modifications to not rely on external buffers has been explored here: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/flash/flash_attention.py
- Kernel Improvements; There are has been some larger updates to the fused attention implementation that Triton uses in its tutorials. The implementation of this kernel is based on a prior version and should be updated.
- We may want to unify this API under the top level SDPA API and leave that as a follow up once this is more stable
- Should we error on CPU?
- There are some issues with dynamic shapes
- Capturing of free variables and lifting to inputs to the subgraph is not working correctly today

### Performance
Comparisons generated by this benchmark:

| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     5.412 |              |             |             |             |            |               |                |
| Max     |     8.882 |           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |
| Min     |     3.645 |            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |
| Min     |     0.345 |            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |

For reference

| Configuration                                 | Forward Time (µ seconds) | Backend          | Speedup |
|-----------------------------------------------|--------------------------|------------------|---------|
| Fastest Config in Sweep (`8 16 4096 4096 64 relative_bias torch.bfloat16`) | 3608                   | Templated Attention                | 1.0  |
| Compiled SDPA (No Mask)                       | 9928                   | Math             | 2.75x   |
| Compiled SDPA (With Mask)                     | 11898                    | Math             | 3.29x   |
| Compiled SDPA (With Mask) | 8704                      | Memory Efficient Attention | 2.42x   |
| Compiled SDPA (No Mask) | 2548                     | FlashAttention2 | 0.706x   |

The speedups are measuring compiled templated attention speed versus different calls to torch.nn.functional.sdpa

<details>

<summary> FULL PERFORMANCE SWEEP NUMBERS </summary>

|   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |   eager_time |   compiled_time |   speedup |
|--------------|-------------|-------------|-------------|------------|---------------|----------------|--------------|-----------------|-----------|
|            1 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      331.444 |          67.221 |     4.931 |
|            1 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      335.300 |          64.187 |     5.224 |
|            1 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      352.039 |          63.806 |     5.517 |
|            1 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      371.699 |         711.349 |     0.523 |
|            1 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |      333.488 |          86.455 |     3.857 |
|            1 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |      322.363 |          82.469 |     3.909 |
|            1 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |      349.967 |          82.233 |     4.256 |
|            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |      486.359 |        1412.453 |     0.344 |
|            1 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |     2794.597 |         551.188 |     5.070 |
|            1 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |     3965.150 |         513.101 |     7.728 |
|            1 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |     2408.013 |         504.759 |     4.771 |
|            1 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |     6850.531 |       16733.675 |     0.409 |
|            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      441.939 |         123.576 |     3.576 |
|            8 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      560.379 |         116.710 |     4.801 |
|            8 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      421.172 |         115.825 |     3.636 |
|            8 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      994.492 |        2132.806 |     0.466 |
|            8 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     1436.430 |         309.495 |     4.641 |
|            8 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     1892.216 |         290.186 |     6.521 |
|            8 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     1360.665 |         282.956 |     4.809 |
|            8 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     3525.532 |        8359.702 |     0.422 |
|            8 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    22026.839 |        3864.604 |     5.700 |
|            8 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    31262.746 |        3609.551 |     8.661 |
|            8 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    20219.079 |        3480.402 |     5.809 |
|            8 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |    54654.647 |      116652.357 |     0.469 |
|           16 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      820.606 |         188.683 |     4.349 |
|           16 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |     1058.362 |         179.295 |     5.903 |
|           16 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      784.372 |         175.714 |     4.464 |
|           16 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |     1890.792 |        4212.877 |     0.449 |
|           16 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     2781.830 |         557.017 |     4.994 |
|           16 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     3694.050 |         525.249 |     7.033 |
|           16 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     2634.164 |         507.613 |     5.189 |
|           16 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     6959.917 |       15331.116 |     0.454 |
|           16 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    43889.096 |        7582.018 |     5.789 |
|           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    62784.293 |        7075.846 |     8.873 |
|           16 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    40308.606 |        6829.587 |     5.902 |
|           16 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |   108892.137 |      233090.953 |     0.467 |
</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121845
Approved by: https://github.com/Chillee, https://github.com/zou3519
2024-04-06 01:10:44 +00:00
drisspg
557e7c9c16 Add some type hints to functions and update a few spelling mistakes (#123015)
# Summary
While working on this PR: https://github.com/pytorch/pytorch/pull/121845
I found that these type hints made my ide/ noob experience easier to reason about

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123015
Approved by: https://github.com/Skylion007
2024-03-30 21:15:01 +00:00
rzou
c81c9ba472 Disallow {FakeTensor,FunctionalTensor}.data_ptr (#122514)
This PR:
- disallows FakeTensor.data_ptr when it is called inside PT2 or fx tracing.
- disallows FunctionalTensor.data_ptr (python FunctionalTensor is only used in
  PT2)

The motivation behind this is that the leading cause of segfaults when
using custom ops with PT2 is calling .data_ptr on FunctionalTensor or
FakeTensor.

This change is BC-breaking. If your code broke as a result of this, it's
because there was a bug in it (these .data_ptr should never be
accessed!). You can either fix the bug (recommended) or get the previous
behavior back with:
```
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor

data_ptr = 0 if isinstance(tensor, (FakeTensor, FunctionalTensor)) else tensor.data_ptr()
```

Test Plan:
- existing tests

Differential Revision: [D55366199](https://our.internmc.facebook.com/intern/diff/D55366199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122514
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/yifuwang, https://github.com/kurtamohler
2024-03-26 23:55:42 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
5b7ceab650 Support auto_functionalize in pre-dispatch (#122177)
Summary: Title

Test Plan: CI

Differential Revision: D55042061

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122177
Approved by: https://github.com/zou3519
2024-03-20 04:17:58 +00:00
albanD
6791b0c09e Change default torch_function behavior to be disabled when torch_dispatch is defined (take 2) (#120632)
This does not introduce a new test but is tested by checking that all the classes we already have still behave as before now that they don't explicitly disable torch_function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120632
Approved by: https://github.com/ezyang
2024-03-09 01:08:37 +00:00
Tugsbayasgalan Manlaibaatar
4f120dc2a6 Clean up mode handling in python dispatcher (#121083)
Things that were bad before this PR:
1. Temporarily unsetting functional tensor mode and proxy mode both had duplicate implementation
2. There are variants of mode handling private utils that has duplicate implementation. (different APIs calling repeated implementation, so i refactored)
3. _push_mode API used to take dispatch key argument which is not necessary.
4. There are unused APIs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121083
Approved by: https://github.com/zou3519
2024-03-08 00:30:34 +00:00
PaulZhang12
c66d68ba51 [PT2] Add tolist() to FunctionalTensor for torch.export (#121242)
Adding tolist() to FunctionalTensor for torch.exporting TorchRec data types
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121242
Approved by: https://github.com/ezyang
2024-03-06 18:10:44 +00:00
angelayi
a7e93c341f [hoo] Add with_effects to handle side effectful ops (#120296)
Proposal: https://docs.google.com/document/d/179QyhicGzTXJ5jvTAoAosP_Nzgf3PpgZwU_E3VV9PlM/edit#heading=h.bnm38nu3yfno
Implementation discussion: https://docs.google.com/document/d/179QyhicGzTXJ5jvTAoAosP_Nzgf3PpgZwU_E3VV9PlM/edit#heading=h.bj61609o1buq

Result with print:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %with_effects : [num_users=1] = call_function[target=torch._higher_order_ops.effects.with_effects](args = (%arg0_1, aten.print.default, moo), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg1_1, %arg1_1), kwargs = {})
    return [getitem, add]
```

Follow ups:
* Add handling to auto_functionalize
* Add support for tokens on the export side
* Add support for tokens on the inductor side

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120296
Approved by: https://github.com/zou3519
2024-03-05 08:58:32 +00:00
Jason Ansel
0e0a621e0c [dynamo] Minor refactors (#120966)
These are changes I pulled out of the above PRs due to not being
related.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120966
Approved by: https://github.com/yanboliang
2024-03-03 02:20:48 +00:00
Tugsbayasgalan Manlaibaatar
c646030cd2 Support higher order op functionalization in predispatch IR (#115314)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115314
Approved by: https://github.com/bdhirsh
2024-03-01 09:13:47 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
8646872ff7 Make balance_gradient preserved in export (#120332)
Summary: We can only not-decompose CompositeImplicit functional custom ops. From the looks of the implementation, this op looks functional. So the fix is just fixing the schema.

Test Plan: CI

Differential Revision: D54019265

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120332
Approved by: https://github.com/zhxchen17
2024-02-23 19:14:08 +00:00
Tugsbayasgalan Manlaibaatar
fa1e89b337 Ban mutation on dropout outputs in export (#117879)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117879
Approved by: https://github.com/ezyang
ghstack dependencies: #117811
2024-01-21 04:53:40 +00:00
Tugsbayasgalan Manlaibaatar
81f98f1082 Experimental non-strict mode (#114658)
This is proof-of-concept implementation of how people can use a marker `mark_strict` to enable torchdynamo while exporting under non-strict mode. The main idea is that `mark_strict` will turn into an HOO which then utilizes dynamo to do correctness analysis in the same way how torch.cond works today. There are some notable limitations:
1. This API is not meant for public use yet
2. Strict region can't work with arbitrary container inputs
3. We don't preserve `nn_module_stack` and other node metadata for the strict region.
4. strict_mode HOO will show up in the final graph. This is undesirable in the long term, but for short term experiments, it should be good enough. Will fix this in the follow up PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114658
Approved by: https://github.com/ydwu4
2024-01-04 12:24:58 +00:00
Tugsbayasgalan Manlaibaatar
dfc898ede4 Don't decompose functional ops in predispatch functionalization (#116383)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116383
Approved by: https://github.com/bdhirsh
ghstack dependencies: #115188, #115210
2023-12-28 11:54:04 +00:00
Tugsbayasgalan Manlaibaatar
76b1d44d57 pre_dispatch aot_export (#115188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115188
Approved by: https://github.com/bdhirsh
2023-12-25 04:51:21 +00:00
PyTorch MergeBot
0567f71ac6 Revert " pre_dispatch aot_export (#115188)"
This reverts commit a267d67350.

Reverted https://github.com/pytorch/pytorch/pull/115188 on behalf of https://github.com/jeanschmidt due to sadly, it is required to revert this commit in order to revert https://github.com/pytorch/pytorch/pull/115454 ([comment](https://github.com/pytorch/pytorch/pull/115188#issuecomment-1866310014))
2023-12-21 14:03:18 +00:00
Tugsbayasgalan Manlaibaatar
a267d67350 pre_dispatch aot_export (#115188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115188
Approved by: https://github.com/bdhirsh
2023-12-20 21:36:25 +00:00
Tugsbayasgalan Manlaibaatar
d85314c95c Support Predispatch functionalization (#113728)
In this PR, we are implementing Functionalization on pre-dispatch graph. Today, every dispatch key except for Dispatchkey.Python has a dedicated mode stack in python. PreDispatch tracing relies on this behaviour by pushing ProxyTorchDispatchMode to Dispatchkey.PreDispatch mode stack and handle the dispatching logic in python. To make pre-dispatch functionalization work, we now need to push FunctionalTensorMode on DispatchKey.PreDispatch mode stack and make sure it runs before ProxyTorchDispatchMode. (this is very similar to how post-dispatch tracing work). Here are some design decisions we made for this flow to work:

1. FunctionalTensorMode internally calls C++ functionalize key. Since C++ functionalization goes after PreDispatch, if we are not careful, we will keep re-entering into PreDispatch key. We solve this by directly dispatching to C++ Functionalize key.

2. We delete mode_stack_per_key logic because the only realistic time it is exercised is for PreDispatch and it is in general not safe to have a plain list because FunctionalTensorMode and ProxyTorchDispatchMode ordering matter and it is hard to enforce it on plain list. Instead, now we have a private class that tracks PreDispatch mode stack.

3.  We will still run CompositeImplicitAutograd decomps in this PR, and disable this logic later as a followup.

Some missing bits after this PR:
1. Preserving autograd ops in a functional form. Right now they still show up in the graph but in a "non-functional" way.
2. Turn off CompositeImplicitAutograd decomps
3. Functionalizing HOO

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113728
Approved by: https://github.com/bdhirsh
2023-12-19 20:28:35 +00:00
rzou
99257002fa Extend auto_functionalized to support ops that return Tensors (#115135)
We can auto-functionalize operators that mutate their inputs as long as
the outputs of the operator do not alias their inputs. The user needs
to provide an abstract impl for the operator if it has non-trivial
returns.
- We update can_auto_functionalize(op) to include ops that return (but
  do not alias) Tensors
- We update auto_functionalized(op, mutated_args_names, kwargs) to
  return (out, mutated_args), where `out = op(**kwargs)` and
  `mutated_args` are the new values of the inputs that would have been
  mutated.

Test Plan:
- new test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115135
Approved by: https://github.com/bdhirsh
ghstack dependencies: #114955, #114956, #115134
2023-12-05 22:43:06 +00:00
rzou
cfa4370c07 torch.compile should auto-functionalize certain mutable ops (#114955)
Users may wish to torch.compile custom ops that mutate their inputs
and return nothing (this is a common class of operators).
torch.compile will automatically support this op without anyone needing
to provide a functionalization kernel for it. Here's how.

Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> ()
op. First, when FakeTensor sees this op, it can just return None.
This is the case because custom ops are not allowed to mutate input
metadata, so the FakeTensor rule for one that returns nothing is trivial.

Next, when Python FunctionalTensor sees the op, it will functionalize
it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...})
HOP and replacing the mutated inputs with the outputs of this HOP.
This HOP effectively runs the functional version of the op when
called: it clones inputs that will be mutated, runs the op, and
then returns Tensors with the new values.

In the future we can teach Inductor how to do re-inplacing when it sees
this HOP (like how triton kernels do it) but this isn't urgent (and is
more of a performance problem).

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114955
Approved by: https://github.com/bdhirsh
2023-12-05 14:53:08 +00:00
Yukio Siraichi
132cb57e47 Skip aliasing correction for lift_fresh. (#112202)
Fix: #111506

This PR skips aliasing correction on `lift_fresh` calls. Reasoning is: although unlifted and lifted tensors are technically aliases, they are from different levels of abstraction (`FunctionalTensorWrapper` and `XLATensor`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112202
Approved by: https://github.com/bdhirsh
2023-11-03 20:46:30 +00:00
Peter Bell
66c32d099a Use pytree.arg_tree_leaves everywhere (#112394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112394
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393
2023-10-31 15:57:06 +00:00
Peter Bell
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
Oguz Ulgen
4e310fd875 [Autograd] Track when mutations are for triton kernels (#111500)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111500
Approved by: https://github.com/bdhirsh
2023-10-19 15:34:34 +00:00
Brian Hirsh
4cf23c6a61 FunctionalTensor: avoid spurious not_implemented logging during proxy tracing (#111040)
This is kind of hard to test, but I can try to add a test case if requested.

I noticed locally that we now end up logging to the ProxyTensorMode and FakeTensorMode `not_implemented` logs in very simple compile examples: https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py#L269

It was because `_mirror_autograd_meta_to()` indirectly queries sizes, and since modes have higher priority than subclasses, `aten::sym_sizes()` was getting dispatched to our modes before going to `FunctionalTensor.__torch_dispatch__`.

This works out fine (they return NotImplemented and we eventually get to `FunctionalTensor`) but I figured we want to avoid cluttering up the logs. So I wrapped the calls with `FunctionalTensorMode`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111040
Approved by: https://github.com/ezyang
2023-10-16 15:18:20 +00:00
Oguz Ulgen
f04b1a0d27 [AOTInductor] Implement autograd eager backend for native triton kernels (#110403)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110403
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
2023-10-04 17:56:56 +00:00
Brian Hirsh
b457e3f79a Reland attempt 2 of "Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (#106406)" (#109906)" (#110079)
The first reland broke internal (failing diff: D49617462).

The major error looks like it's because there's an internal-only higher order op that needs a new functionalization rule. I'm going to land an internal diff for that and confirm tests pass before relanding this PR.

Also confirmed that the issue from https://github.com/pytorch/pytorch/issues/110121 is fixed, and added a test.

This reverts commit 1b90f07f5a.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110079
Approved by: https://github.com/ezyang
2023-10-03 18:50:25 +00:00
Brian Hirsh
63526a63f5 Make FunctionalTensor subclass to be more like functorch (interaction with ZeroTensor + Conjugate key) (#109023)
I added some tests for Conj, Neg and ZeroTensor for both python and C++ functionalization. This also fixes a nasty segfult when running a functorch `jacfwd` test with `torch.compile`, once AOTAutograd is using `FunctionalTensor`.

Changes:

(1) I use Jeffrey's `make_wrapper_subclass(extra_dispatch_keys)` kwarg to plumb extra dispatch keys ontoto the wrapper, mirroring what C++ functionalization does (C++ functionalization will mirror all dispatch keys from the inner tensor to the wrapper, except for python and functorch keys).

(2) FunctionalTensorMode will decompose CompositeImplicitAutograd ops, since (for example) ZeroTensor kernels can send ops like `.to()` directly to the Python key. We'll need a way to toggle this later for pre-dispatch functionalization

(3) Bound `_ForceDispatchKeyGuard` and BatchedTensorImpl's dispatch keyset to python

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109023
Approved by: https://github.com/zou3519
ghstack dependencies: #108654, #109662, #109632
2023-09-22 07:09:04 +00:00
hauntsaninja
2cd0b94533 Hide __getattr__ from type checkers (#109683)
Visibility of this causes type checkers to conservatively assume that all attributes are defined on torch module.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109683
Approved by: https://github.com/ngimel, https://github.com/ezyang, https://github.com/malfet
2023-09-21 17:01:23 +00:00
Brian Hirsh
238fb66085 python functionalization: support higher order ops (#108656)
We now have two types of functionalization, C++ Functionalization (through the `Functionalize` dispatch key), and python functionalization (through the `FunctionalTensorMode` torch_dispatch mode).

This means that all higher order ops need custom functionalization rules for the python variant too. I added them here, as well as a helper function `dispatch_functionalize()` - equivalent to `torch.func.functionalize()`, except that it uses `FunctionalTensorMode`.

In theory we could have secretly switched `torch.func.functionalize` to use `FunctionalTensorMode`. This would be BC-breaking, though, since `FunctionalTensorMode` isn't composable with the other functorch transforms (the functorch layer-mode stack doesn't know how to re-order torch_dispatch modes arbitrarily).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108656
Approved by: https://github.com/zou3519
ghstack dependencies: #109024, #109248
2023-09-20 04:37:31 +00:00
Brian Hirsh
25e81f19f3 reland "python functionalization: add helpers, functionalize_sync and mirror_autograd_meta (#107917)" (#109518)
Reland - the previous PR was reverted by internal with this error:
```
  File "/data/sandcastle/boxes/eden-trunk-hg-fbcode-fbsource/buck-out/v2/gen/fbcode/363cd7e240f5d021/caffe2/torch/fb/trainer/data_modules/tests/__test_dataloader__/test_dataloader#link-tree/torch/__init__.py", line 29, in <module>
    from ._utils_internal import _functionalize_sync as _sync
ImportError: cannot import name '_functionalize_sync' from 'torch._utils_internal'
```

I couldn't figure out why internal was unhappy with the import. One potential reason is that I see a build rule for *another* `_utils_internal.py` in the fb folder here ([link](https://www.internalfb.com/code/fbsource/[30ed85cd88409af98b7490be137aaa5dfd7afd01]/fbcode/caffe2/TARGETS?lines=444))

Rather than burn more time investigating, I confirmed internally that the error goes away if I move the util from `torch/_utils_internal.py` to `torch/_utils.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109518
Approved by: https://github.com/albanD
2023-09-19 13:25:24 +00:00
PyTorch MergeBot
49b18ae546 Revert "python functionalization: add helpers, functionalize_sync and mirror_autograd_meta (#107917)"
This reverts commit 0ad595954a.

Reverted https://github.com/pytorch/pytorch/pull/107917 on behalf of https://github.com/clee2000 due to breaking internal builds D49346637 ([comment](https://github.com/pytorch/pytorch/pull/107917#issuecomment-1722566885))
2023-09-17 20:57:41 +00:00
Brian Hirsh
0ad595954a python functionalization: add helpers, functionalize_sync and mirror_autograd_meta (#107917)
Added two new utils to help with turning python functionalization on in AOTAutograd (next PR):

(1) updated `torch._sync()`. Previously, this API could only handle `torch.Tensor` instances that had a `FunctionalTensorWrapper` TensorImpl. It now needs to handle python `FunctionalTensor`'s. In theory I can probably break BC and change this API (since it's private?), but I decided not to do it in this PR stack do minimize the chance of reverts. Instead of updating that API directly (which is in C++), I just added a python shim that first tries to unwrap the python `FunctionalTensor` if there is one, then calls the existing C++ logic

(2) `mirror_autograd_meta` is now a standalone API that tries to mirror the `requires_grad` and `is_leaf` autograd metadata from one tensor to another. Previously this was hardcoded into `torch._to_functional_tensor()`. But I now need to use it in a more standalone way: later in AOTAutograd when we unwrap and re-wrap a tensor subclasses, we need to manually mirror the autograd metadata from the original to the updated version of the subclass.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107917
Approved by: https://github.com/ezyang
ghstack dependencies: #106404
2023-09-15 20:19:25 +00:00
Brian Hirsh
f22b303f65 Add TorchDispatch version of functionalization (#106404)
This PR adds a new `FunctionalTensor` subclass, and `FunctionalTensorMode` torch dispatch mode. Together, this class/mode are a lightweight wrapper around our existing C++ functionalization logic.

This idea came from Ed - later in the stack, I want to be able to run functionalization **underneath** torch_dispatch, when performing tracing in AOTAutograd. I can't do this easily with vanilla C++ functionalization, because it has a dedicated dispatch key that always runs before TorchDispatch. However, by adding a torch_dispatch mode shim around functionalization, we can use functionalization as a torch_dispatch mode, which will make it easier to run underneath other modes later.

This PR provides the basic new classes, and some light testing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106404
Approved by: https://github.com/ezyang
2023-09-15 20:19:25 +00:00