pytorch/test/inductor
chilli e0c5113dec Add support for capturing tensors with score_mod (#124444)
```
import torch
from torch import nn
import torch.nn.functional as F
import torch._inductor.config as config
# torch.set_default_device('cuda')

import torch
from torch.nn.attention._templated_attention import _templated_attention as templated_attention
from triton.testing import do_bench
from torch.nn.attention import SDPBackend, sdpa_kernel

# Short alias for the ATen operator namespace (not referenced again in this snippet).
index = torch.ops.aten
# Fix the RNG so the random inputs below are reproducible run to run.
torch.manual_seed(0)

# Problem size: batch, heads, sequence length, head dimension.
B, H, S, D = 16, 16, 2048, 64

# Random per-head multiplier consumed by the `alibi` score_mod below.
head_scale = torch.randn(H, device="cuda")
def alibi(score, batch, head, token_q, token_kv):
    """ALiBi-style score_mod: add ``head_scale[head] * (token_q - token_kv)`` to ``score``.

    ``batch`` is part of the score_mod signature but is not used here. The
    per-head factor is read via ``torch.ops.aten.index`` on the module-level
    ``head_scale`` tensor, which this function captures.
    """
    slope = torch.ops.aten.index(head_scale, [head])
    offset = token_q - token_kv
    return score + slope * offset
# Dense additive mask of shape (H, S, S); it broadcasts over the batch dim of
# the (B, H, S, D) inputs and serves as the "SDPA with a bias mask" baseline.
bias = torch.randn(H, S, S, dtype=torch.float16, device='cuda')

query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

# Check that the compiled templated attention matches the eager reference
# before timing anything.
compiled = torch.compile(templated_attention)
out = compiled(query, key, value, score_mod=alibi)
out2 = templated_attention(query, key, value, score_mod=alibi)
mean_abs_err = (out - out2).abs().mean()  # compute once, reuse for print and check
print(mean_abs_err)
assert mean_abs_err < 1e-3

print("Flash (no mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value)))
# SDPA's flash-attention backend does not accept an attn_mask, so this call is
# dispatched to a different backend; label it as SDPA rather than "Flash".
print("SDPA (mask): ", do_bench(lambda: F.scaled_dot_product_attention(query, key, value, attn_mask=bias)))
print("flexattention: ", do_bench(lambda: compiled(query, key, value, score_mod=alibi)))
```
<img width="324" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/18c175d0-2720-4dfd-8747-85b8a8f609f5">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124444
Approved by: https://github.com/jansel, https://github.com/drisspg
2024-04-23 06:20:13 +00:00
..
cpp
extension_backends [Inductor] [ReImplement] Outer Loop Fusion for CPP Backend (#121625) 2024-04-05 06:24:57 +00:00
__init__.py
indirect_assert_helper.py [Inductor Intel GPU backend Upstream] Generalize part of Inductor test case (#117513) 2024-01-18 08:26:21 +00:00
minifier_smoke.py
opinfo_harness.py
test_aot_inductor_utils.py [aoti] Change aot_compile callsites (#122225) 2024-03-29 21:34:20 +00:00
test_aot_inductor.py [inductor] Refactor runtime files into torch._inductor.runtime (part 3) (#124557) 2024-04-22 18:46:24 +00:00
test_benchmark_fusion.py Defer selection of triton template (#120275) 2024-03-20 01:40:33 +00:00
test_binary_folding.py Enable FX graph cache for a batch of inductor tests (#121696) 2024-03-14 03:39:59 +00:00
test_codecache.py [inductor] Refactor runtime files into torch._inductor.runtime (part 3) (#124557) 2024-04-22 18:46:24 +00:00
test_codegen_triton.py Enable FX graph caching in another batch of inductor tests (#121697) 2024-03-15 19:38:51 +00:00
test_compiled_autograd.py [aot] trim refcount for subclass runtime wrapper (#124155) 2024-04-18 02:34:52 +00:00
test_compiled_optimizers.py Enable dynamo traced test_forloop_goes_right_direction (#123322) 2024-04-18 00:50:10 +00:00
test_config.py Enable FX graph cache for a batch of inductor tests (#121696) 2024-03-14 03:39:59 +00:00
test_control_flow.py Add torch.while_loop support to AOT Inductor (#123586) 2024-04-09 22:53:10 +00:00
test_coordinate_descent_tuner.py [inductor] Use compile time config values in runtime (#124561) 2024-04-22 18:46:40 +00:00
test_cpu_cpp_wrapper.py [AOTI] Fixes ScatterFallback codegen (#124580) 2024-04-22 20:47:26 +00:00
test_cpu_repro.py [Inductor] Enable VecMask store (#123710) 2024-04-23 00:29:47 +00:00
test_cuda_cpp_wrapper.py Revert "fix Invalid call to aoti_torch_tensor_copy_ #123039 (#124037)" 2024-04-22 07:20:10 +00:00
test_cuda_repro.py [inductor] Refactor runtime files into torch._inductor.runtime (part 2) (#124553) 2024-04-22 18:46:20 +00:00
test_cudacodecache.py Enable FX graph caching in another batch of inductor tests (#121697) 2024-03-15 19:38:51 +00:00
test_cudagraph_trees.py [CUDAGraphTree] Support mutated inputs from prior cudagraph pool (#123231) 2024-04-19 10:32:12 +00:00
test_custom_lowering.py Enable FX graph caching in another batch of inductor tests (#121697) 2024-03-15 19:38:51 +00:00
test_custom_post_grad_passes.py Add custom joint graph passes (#124443) 2024-04-19 11:54:46 +00:00
test_cutlass_backend.py [Inductor cutlass backend] Fix tests: skipIfROCm always skips when using as class annotation (#123930) 2024-04-22 13:59:59 +00:00
test_debug_trace.py Enable FX graph cache for a batch of inductor tests (#121696) 2024-03-14 03:39:59 +00:00
test_decompose_mem_bound_mm.py realize inputs to mem bound mm decomposition (#123165) 2024-04-18 23:10:04 +00:00
test_dependencies.py Enable FX graph caching in another batch of inductor tests (#121697) 2024-03-15 19:38:51 +00:00
test_distributed_patterns.py [compiled autograd][dynamo] Make compiled graph take in boxed inputs (#122353) 2024-04-12 10:29:09 +00:00
test_efficient_conv_bn_eval.py Implement efficient_conv_bn_eval_decomp_graph_transform to handle conv and bn fusion after decomp (#123680) 2024-04-19 00:22:25 +00:00
test_extension_backend.py fix maybe_initialize_device for custom device. (#121379) 2024-04-09 16:58:52 +00:00
test_foreach.py Enable FX graph caching in another batch of inductor tests (#121697) 2024-03-15 19:38:51 +00:00
test_fp8.py Enable FX graph cache for a batch of inductor tests (#121696) 2024-03-14 03:39:59 +00:00
test_fused_attention.py Made several changes to min-cut partitioner that allow it to recompute more things (#121692) 2024-03-27 22:45:52 +00:00
test_fx_fusion.py [BE]: Optimize min/max/sum comprehensions C419 (#123960) 2024-04-12 23:54:15 +00:00
test_group_batch_fusion.py [Inductor]Fix a couple of broken unit tests (#122714) 2024-03-28 17:44:30 +00:00
test_indexing.py Enable FX graph caching in another batch of inductor tests (#121697) 2024-03-15 19:38:51 +00:00
test_inductor_freezing.py [Easy] Fix freezing bug with mismatched bias sizes (#122724) 2024-03-27 01:41:00 +00:00
test_inductor_utils.py [inductor] Refactor runtime files into torch._inductor.runtime (part 3) (#124557) 2024-04-22 18:46:24 +00:00
test_inplacing_pass.py Enable FX graph cache for a batch of inductor tests (#121696) 2024-03-14 03:39:59 +00:00
test_kernel_benchmark.py Enable FX graph cache for a batch of inductor tests (#121696) 2024-03-14 03:39:59 +00:00
test_layout_optim.py Enable FX graph cache for a batch of inductor tests (#121696) 2024-03-14 03:39:59 +00:00
test_max_autotune.py Fix test_max_autotune_remote_caching (#124655) 2024-04-23 01:41:04 +00:00
test_memory_planning.py Fix memory planning compile error (#123867) 2024-04-12 17:34:58 +00:00
test_metrics.py Enable FX graph cache for a batch of inductor tests (#121696) 2024-03-14 03:39:59 +00:00
test_minifier_isolate.py [Inductor Intel GPU backend Upstream] Add Inductor Intel GPU backend. (#121895) 2024-04-05 09:05:11 +00:00
test_minifier.py [Inductor Intel GPU backend Upstream] Add Inductor Intel GPU backend. (#121895) 2024-04-05 09:05:11 +00:00
test_mkldnn_pattern_matcher.py [Inductor pattern] support int8 woq mm pattern matcher with freezing passe (#122955) 2024-04-09 05:06:52 +00:00
test_mmdecomp.py [dynamo, 3.12] enable tests disabled due to missing dynamo 3.12 support (#123300) 2024-04-05 20:13:17 +00:00
test_move_constructors_to_cuda.py Enable FX graph cache for a batch of inductor tests (#121696) 2024-03-14 03:39:59 +00:00
test_multi_kernel.py [inductor] Fix fresh_inductor_cache() (#122661) 2024-04-15 20:28:54 +00:00
test_pad_mm.py [Inductor] Run pattern matcher over the original graph (#122519) 2024-03-27 22:09:36 +00:00
test_padding.py [inductor] Refactor runtime files into torch._inductor.runtime (part 3) (#124557) 2024-04-22 18:46:24 +00:00
test_pattern_matcher.py Add mm_pattern and bmm_pattern to serialized_patterns (#121313) 2024-04-09 19:42:19 +00:00
test_perf.py [inductor] for UserDefinedTritonKernels don't mark all inputs as mutating (#124425) 2024-04-21 06:00:14 +00:00
test_profiler.py Enhance RecordFunctionFast input args and use input args in triton_heuristics.py (#123459) 2024-04-06 02:44:06 +00:00
test_select_algorithm.py Re-land precompile triton templates (#124030) 2024-04-19 17:03:33 +00:00
test_smoke.py Enable FX graph caching in another batch of inductor tests (#121697) 2024-03-15 19:38:51 +00:00
test_snode_runtime.py Enable FX graph caching in another batch of inductor tests (#121697) 2024-03-15 19:38:51 +00:00
test_split_cat_fx_passes.py [PT2][Inductor][3/n] Customize pre grad and post grad patterns (#121915) 2024-04-03 21:37:21 +00:00
test_standalone_compile.py Enable FX graph cache for a batch of inductor tests (#121696) 2024-03-14 03:39:59 +00:00
test_templated_attention.py Add support for capturing tensors with score_mod (#124444) 2024-04-23 06:20:13 +00:00
test_torchinductor_codegen_dynamic_shapes.py [Inductor] Support custom op in JIT with cpp wrapper (#122554) 2024-03-26 18:48:45 +00:00
test_torchinductor_dynamic_shapes.py Excise uses of the old custom ops APIs (#124134) 2024-04-19 17:56:26 +00:00
test_torchinductor_opinfo.py Add index_reduce decomposition (#122579) 2024-04-18 01:30:47 +00:00
test_torchinductor.py [inductor, test] remove cast for test_tmp_not_defined_issue2_cpu (#114910) 2024-04-22 21:51:53 +00:00
test_triton_extension_backend.py [Inductor] Add a test for creating a cpu inductor-> triton backend (#122396) 2024-03-23 01:14:57 +00:00
test_triton_heuristics.py [inductor] Use compile time config values in runtime (#124561) 2024-04-22 18:46:40 +00:00
test_triton_kernels.py [inductor] for UserDefinedTritonKernels don't mark all inputs as mutating (#124425) 2024-04-21 06:00:14 +00:00
test_triton_wrapper.py Enable FX graph cache for a batch of inductor tests (#121696) 2024-03-14 03:39:59 +00:00
test_unbacked_symints.py [inductor] simplify expr when looking up size hint (#123140) 2024-04-04 04:59:59 +00:00
test_utils.py Enable FX graph cache for a batch of inductor tests (#121696) 2024-03-14 03:39:59 +00:00
test_xpu_basic.py [Inductor Intel GPU backend Upstream] Add Inductor Intel GPU backend. (#121895) 2024-04-05 09:05:11 +00:00