pytorch/torch/_dynamo/variables
chilli e0c5113dec Add support for capturing tensors with score_mod (#124444)
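The script below exercises the new capability: `alibi` is a `score_mod` that closes over the `head_scale` tensor, which `torch.compile` can now capture when lowering templated attention. The compiled result is checked against eager execution and benchmarked against SDPA baselines.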
```
import torch
import torch.nn.functional as F

from torch.nn.attention._templated_attention import _templated_attention as templated_attention
from triton.testing import do_bench

torch.manual_seed(0)

# Batch size, number of heads, sequence length, head dimension.
B = 16
H = 16
S = 2048
D = 64

# Per-head slopes captured by the score_mod closure below; this PR lets
# torch.compile lift a captured tensor like this into the attention kernel.
head_scale = torch.randn(H, device="cuda")

def alibi(score, batch, head, token_q, token_kv):
    # ALiBi-style bias: scale the query/key distance by a per-head slope.
    return score + torch.ops.aten.index(head_scale, [head]) * (token_q - token_kv)

# Dense [H, S, S] bias used as an explicit attn_mask in the SDPA baseline.
bias = torch.randn(H, S, S, dtype=torch.float16, device="cuda")

query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

compiled = torch.compile(templated_attention)
out = compiled(query, key, value, score_mod=alibi)
out2 = templated_attention(query, key, value, score_mod=alibi)

# Compiled and eager results should agree to within fp16 tolerance.
print((out - out2).abs().mean())
assert (out - out2).abs().mean() < 1e-3

# Benchmarks: SDPA without a mask, SDPA with the explicit bias mask, and the
# compiled templated attention with the alibi score_mod.
print("Flash (no mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value)))
print("Flash (mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value, attn_mask=bias)))
print("flexattention: ", do_bench(lambda: compiled(query, key, value, score_mod=alibi)))
```
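The same mechanism should extend to closing over larger tensors. As a hypothetical sketch (not part of this PR's test script), the `bias` table defined above could itself be read from a `score_mod`, mirroring the `aten.index` pattern used in `alibi`; support for multi-index reads of a captured tensor in this form is an assumption here.
```
# Hypothetical sketch: a score_mod that reads a captured [H, S, S] bias
# table per (head, query, key) position. Assumes multi-index captured-tensor
# reads are supported, mirroring the aten.index style used in alibi above.
def table_bias(score, batch, head, token_q, token_kv):
    return score + torch.ops.aten.index(bias, [head, token_q, token_kv])

out_bias = compiled(query, key, value, score_mod=table_bias)
```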
(Screenshot in the PR shows the benchmark output for the three timings above.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124444
Approved by: https://github.com/jansel, https://github.com/drisspg
2024-04-23 06:20:13 +00:00
| File | Last commit | Date |
| --- | --- | --- |
| `__init__.py` | [dynamo] Improve constant-prop for regex/`torch.__version__` (#123705) | 2024-04-12 19:03:13 +00:00 |
| `base.py` | [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261) | 2024-04-17 19:29:34 +00:00 |
| `builder.py` | [BE]: Update ruff to 0.4.1 (#124549) | 2024-04-21 14:06:23 +00:00 |
| `builtin.py` | [dynamo] support `object.__setattr__(obj, name, value)` (#124068) | 2024-04-17 15:57:14 +00:00 |
| `constant.py` | [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261) | 2024-04-17 19:29:34 +00:00 |
| `ctx_manager.py` | [dynamo] Support `warnings.catch_warnings` (#123511) | 2024-04-08 22:27:46 +00:00 |
| `dicts.py` | [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261) | 2024-04-17 19:29:34 +00:00 |
| `distributed.py` | Don't create world pg variable out of thin air when rewriting c10d collectives (#122561) | 2024-03-26 20:12:08 +00:00 |
| `functions.py` | [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261) | 2024-04-17 19:29:34 +00:00 |
| `higher_order_ops.py` | Add support for capturing tensors with score_mod (#124444) | 2024-04-23 06:20:13 +00:00 |
| `iter.py` | [dynamo] Improve constant-prop for regex/`torch.__version__` (#123705) | 2024-04-12 19:03:13 +00:00 |
| `lazy.py` | [dynamo] Replace `VariableTracker.apply` with visit/realize_all (#122218) | 2024-03-20 07:53:18 +00:00 |
| `lists.py` | [Dynamo] Fix NamedTuple hasattr bug (#124531) | 2024-04-21 04:36:22 +00:00 |
| `misc.py` | Introduce set_example_value and use it throughout Dynamo (#124176) | 2024-04-17 22:57:11 +00:00 |
| `nn_module.py` | [dynamo][easy] forbid_in_graph check to use getattr_static (#124445) | 2024-04-20 14:11:05 +00:00 |
| `optimizer.py` | Defer marking_static_address (#124309) | 2024-04-19 17:20:57 +00:00 |
| `sdpa.py` | [dynamo] Refactor `reconstruct()` not to return anything (#120150) | 2024-02-17 17:13:41 +00:00 |
| `tensor.py` | Introduce set_example_value and use it throughout Dynamo (#124176) | 2024-04-17 22:57:11 +00:00 |
| `torch_function.py` | [dynamo] Optimize SourcelessBuilder (#122063) | 2024-03-19 04:23:30 +00:00 |
| `torch.py` | [dynamo] support `object.__setattr__(obj, name, value)` (#124068) | 2024-04-17 15:57:14 +00:00 |
| `user_defined.py` | Enable SourcelessBuilder to build GraphModule generated by make_fx (#123673) | 2024-04-19 17:23:51 +00:00 |