pytorch/torch/_dynamo
chilli e0c5113dec Add support for capturing tensors with score_mod (#124444)
Example: a captured tensor (`head_scale`) used inside a `score_mod`, checked against the eager implementation and benchmarked against SDPA:
```python
import torch
import torch.nn.functional as F
from torch.nn.attention._templated_attention import _templated_attention as templated_attention
from triton.testing import do_bench

torch.manual_seed(0)

B = 16    # batch size
H = 16    # number of heads
S = 2048  # sequence length
D = 64    # head dimension

# Per-head slope tensor captured by the score_mod closure.
head_scale = torch.randn(H, device="cuda")

def alibi(score, batch, head, token_q, token_kv):
    # ALiBi-style bias: scale the relative position by a per-head slope,
    # indexing the captured tensor with the head index.
    return score + torch.ops.aten.index(head_scale, [head]) * (token_q - token_kv)

bias = torch.randn(H, S, S, dtype=torch.float16, device="cuda")

query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

compiled = torch.compile(templated_attention)
out = compiled(query, key, value, score_mod=alibi)
out2 = templated_attention(query, key, value, score_mod=alibi)
print((out - out2).abs().mean())
assert (out - out2).abs().mean() < 1e-3

print("Flash (no mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value)))
print("Flash (mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value, attn_mask=bias)))
print("flexattention: ", do_bench(lambda: compiled(query, key, value, score_mod=alibi)))
```
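
For intuition, the `score_mod` semantics can be reproduced eagerly in plain PyTorch by materializing the full score matrix and applying the modification elementwise before the softmax. This is only an illustrative sketch of the contract, not the PR's implementation (`templated_attention` never materializes the score matrix); the helper name `eager_score_mod_attention` is hypothetical.

```python
import torch
import torch.nn.functional as F

def eager_score_mod_attention(q, k, v, score_mod):
    """Reference semantics (hypothetical helper): each score[b, h, i, j] is
    passed through score_mod together with its (batch, head, query, key)
    indices before the softmax. score_mod must broadcast elementwise."""
    B, H, S, D = q.shape
    scores = (q @ k.transpose(-2, -1)) / (D ** 0.5)
    b = torch.arange(B).view(B, 1, 1, 1)
    h = torch.arange(H).view(1, H, 1, 1)
    i = torch.arange(S).view(1, 1, S, 1)
    j = torch.arange(S).view(1, 1, 1, S)
    scores = score_mod(scores, b, h, i, j)
    return torch.softmax(scores, dim=-1) @ v

# With the identity score_mod this reduces to standard attention:
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
out = eager_score_mod_attention(q, k, v, lambda s, b, h, i, j: s)
ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out, ref, atol=1e-5))
```

A captured tensor works the same way here: an ALiBi-style `lambda s, b, h, i, j: s + head_scale[h] * (i - j)` closes over `head_scale` and broadcasts against the index grids.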
<img width="324" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/18c175d0-2720-4dfd-8747-85b8a8f609f5">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124444
Approved by: https://github.com/jansel, https://github.com/drisspg
2024-04-23 06:20:13 +00:00
| Name | Last commit | Date |
|---|---|---|
| backends | [dynamo] Return gm.forward for eager backend (#124109) | 2024-04-20 14:11:05 +00:00 |
| repro | [minifier] Add config flag to ignore non-fp values (#123006) | 2024-04-09 03:34:09 +00:00 |
| variables | Add support for capturing tensors with score_mod (#124444) | 2024-04-23 06:20:13 +00:00 |
| __init__.py | [torch.compile] Provide capability to register callback on compile start/stop (#120764) | 2024-02-29 07:37:52 +00:00 |
| _trace_wrapped_higher_order_op.py | [Compiled Autograd] Introduce BackwardState capture (#120382) | 2024-02-28 20:36:47 +00:00 |
| bytecode_analysis.py | [dynamo] fix call_finally issue in Python 3.8 (#124122) | 2024-04-16 08:36:20 +00:00 |
| bytecode_transformation.py | [dynamo, 3.12] fix positions and offsets of added instructions when we clean (#123991) | 2024-04-14 03:58:04 +00:00 |
| cache_size.py | Chore: improve log message about cache size limit exceeded (#116557) | 2024-01-17 06:07:18 +00:00 |
| callback.py | [torch.compile] Provide capability to register callback on compile start/stop (#120764) | 2024-02-29 07:37:52 +00:00 |
| code_context.py | [dynamo] preserve some FX node metadata of GraphModules (#107067) | 2023-09-15 23:29:14 +00:00 |
| codegen.py | [dynamo] Support custom __setattr__ on UserDefinedObjectVariable (#123318) | 2024-04-07 21:06:52 +00:00 |
| compiled_autograd.py | [compiled autograd][dynamo] Make compiled graph take in boxed inputs (#122353) | 2024-04-12 10:29:09 +00:00 |
| comptime.py | Make torch._dynamo.mark_static work inside graph (#118962) | 2024-02-02 20:01:27 +00:00 |
| config.py | [dynamo][cpp-guard] Reland Attempt 1 - Enable cpp guard manager (#124231) | 2024-04-18 06:36:20 +00:00 |
| convert_frame.py | rename sl to strobelight (#124455) | 2024-04-19 22:50:13 +00:00 |
| create_parameter_op.py | [dynamo] Add support for nn.Parameter constructor (part 2) (#120965) | 2024-03-16 04:29:58 +00:00 |
| current_scope_id.py | [HigherOrderOp] Fall back on all new side effects in speculate_subgraph (#104077) | 2023-06-28 14:20:37 +00:00 |
| debug_utils.py | [dynamo] Forward OptimizedModule.__setattr__ to the wrapped module (#122098) | 2024-04-01 14:30:44 +00:00 |
| decorators.py | feat: Add min, max ranges to mark_dynamic API (#119737) | 2024-03-07 23:26:03 +00:00 |
| device_interface.py | [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261) | 2024-04-17 19:29:34 +00:00 |
| eval_frame.py | Assert that TracingContext is available when set_example_value is called (#124284) | 2024-04-21 11:23:13 +00:00 |
| exc.py | [Dynamo] fix opcode YIELD_FROM and SEND (#123912) | 2024-04-12 21:57:47 +00:00 |
| external_utils.py | [dynamo] Support module backwards hooks (#120685) | 2024-03-01 02:24:26 +00:00 |
| funcname_cache.py | [dynamo] Enable typechecking for funcname_cache.py (#112031) | 2023-10-26 04:54:16 +00:00 |
| guards.py | [dynamo] Graph break on uninitialized nn.Module (#123790) | 2024-04-12 19:03:13 +00:00 |
| hooks.py | [dynamo] Enable typechecking for hooks.py (#112565) | 2023-11-04 19:37:06 +00:00 |
| logging.py | Add stack trace to "start tracing" log (#118217) | 2024-01-25 06:53:12 +00:00 |
| mutation_guard.py | [dynamo] Config option to Inline builtin nn module forward (#122725) | 2024-03-28 03:01:27 +00:00 |
| output_graph.py | [Export] Add runtime assert to non-strict export (#123681) | 2024-04-18 16:13:27 +00:00 |
| polyfill.py | [dynamo] Improve constant-prop for regex/torch.__version__ (#123705) | 2024-04-12 19:03:13 +00:00 |
| profiler.py | Remove size asserts from fx_insert_profiling (#114830) | 2023-12-04 19:08:36 +00:00 |
| replay_record.py | [CI] Install dill in ci (#116214) | 2024-01-24 23:42:35 +00:00 |
| resume_execution.py | [dynamo, 3.12] handle possibility of NULL local variables during graph breaks (#124095) | 2024-04-16 08:44:43 +00:00 |
| side_effects.py | [dynamo] Update co_names if needed in fix_vars (#123697) | 2024-04-11 01:00:01 +00:00 |
| source.py | [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261) | 2024-04-17 19:29:34 +00:00 |
| symbolic_convert.py | Revert "[Dynamo] Check for __bool__ attribute before accessing it (#120943)" | 2024-04-18 06:34:32 +00:00 |
| tensor_version_op.py | Support torchbind op dispatch in python (#123367) | 2024-04-19 17:17:27 +00:00 |
| test_case.py | [dynamo, 3.12] enable dynamo on 3.12, enable most dynamo unittests on 3.12 (#123216) | 2024-04-04 20:00:54 +00:00 |
| test_minifier_common.py | Enable possibly-undefined error code (#118533) | 2024-01-30 21:07:01 +00:00 |
| testing.py | [dynamo] Return gm.forward for eager backend (#124109) | 2024-04-20 14:11:05 +00:00 |
| trace_rules.py | preferred blas library; cublaslt gemm implementation (#122106) | 2024-04-22 15:38:22 +00:00 |
| types.py | [dynamo] delete dynamo cache entry when guard function is invalidated [attempt 2] (#119107) | 2024-02-07 03:32:42 +00:00 |
| utils.py | Assert that TracingContext is available when set_example_value is called (#124284) | 2024-04-21 11:23:13 +00:00 |