Benchmark: compiled templated attention (flexattention) with an ALiBi `score_mod`, checked against the eager reference and compared to the SDPA flash kernels.

```
import torch
import torch.nn.functional as F
from torch.nn.attention._templated_attention import _templated_attention as templated_attention
from triton.testing import do_bench

torch.manual_seed(0)

# Problem sizes: batch, heads, sequence length, head dim.
B = 16
H = 16
S = 2048
D = 64

# Per-head ALiBi slopes.
head_scale = torch.randn(H, device="cuda")

def alibi(score, batch, head, token_q, token_kv):
    # Add a linear positional bias to each score, scaled per head.
    return score + torch.ops.aten.index(head_scale, [head]) * (token_q - token_kv)

bias = torch.randn(H, S, S, dtype=torch.float16, device="cuda")
query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

# Check the compiled kernel against the eager reference.
compiled = torch.compile(templated_attention)
out = compiled(query, key, value, score_mod=alibi)
out2 = templated_attention(query, key, value, score_mod=alibi)
print((out - out2).abs().mean())
assert (out - out2).abs().mean() < 1e-3

print("Flash (no mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value)))
print("Flash (mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value, attn_mask=bias)))
print("flexattention: ", do_bench(lambda: compiled(query, key, value, score_mod=alibi)))
```
<img width="324" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/18c175d0-2720-4dfd-8747-85b8a8f609f5">
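For reference, the `score_mod` hook can express masking as well as biasing. The sketch below (an illustration, not part of the original benchmark) assumes the same `templated_attention` entry point and the `compiled`, `query`, `key`, and `value` names defined above, and implements a causal mask by replacing future-position scores with `-inf`:

```
# Minimal sketch, assuming the setup above: mask out positions where the
# key index is ahead of the query index by sending their scores to -inf.
def causal(score, batch, head, token_q, token_kv):
    return torch.where(token_q >= token_kv, score, float("-inf"))

causal_out = compiled(query, key, value, score_mod=causal)
```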
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124444
Approved by: https://github.com/jansel, https://github.com/drisspg