diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index b38f6fc88df..69d51aadd80 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -37,13 +37,9 @@ from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, TEST_MULTIGPU
 from torch.testing._internal.common_device_type import (
-    dtypes,
-    dtypesIfCUDA,
     flex_attention_supported_platform as supported_platform,
     instantiate_device_type_tests,
     largeTensorTest,
-    skipCPUIf,
-    skipCUDAIf,
 )
 from torch.utils._triton import has_triton
 
@@ -94,18 +90,6 @@ def temp_float32_matmul_precision(precision: str):
         torch.set_float32_matmul_precision(original_precision)
 
 
-def skip_on_cpu(test_func):
-    """Decorator to skip tests that are not supported on CPU."""
-    decorated_func = skipCPUIf(True, "Not supported on CUDA")(test_func)
-    return decorated_func
-
-
-def skip_on_cuda(test_func):
-    """Decorator to skip tests that are not supported on CUDA."""
-    decorated_func = skipCUDAIf(True, "Not supported on CUDA")(test_func)
-    return decorated_func
-
-
 def rmse(ref, res):
     """
     Calculate root mean squared error
@@ -134,63 +118,39 @@ def create_block_mask_test(score_mod, query, key):
     return block_mask
 
 
-@dataclass
-class DeviceConfig:
-    dtypes: list[torch.dtype]
-    dtypes_fast: list[torch.dtype]
-
-
 TEST_ON_CUDA = (
     torch.cuda.is_available()
    and torch.utils._triton.has_triton()
    and torch.cuda.get_device_capability() >= (8, 0)
 )
 
-device_configs = {}
-test_device = ("cpu", "cuda")
-
-
-class SubstringSet:
-    def __init__(self, items):
-        self.items = set(items)
-
-    def __contains__(self, item):
-        if "cuda" in item:
-            item = "cuda"
-        return item in self.items
-
-
-DEVICE_SUPPORTS_BACKWARDS = SubstringSet(
-    [
-        "cuda",
-    ]
-)
-
-device_configs["cuda"] = DeviceConfig(
-    dtypes=(
+if TEST_ON_CUDA:
+    test_device = ("cuda",)
+    test_dtypes = (
         [torch.float32, torch.bfloat16, torch.float16]
         if PLATFORM_SUPPORTS_BF16
         else [torch.float16, torch.float32]
-    ),
-    dtypes_fast=[torch.float16],
-)
-device_configs["cpu"] = DeviceConfig(
-    dtypes=(
+    )
+    test_dtypes_fast = [torch.float16]
+    SKIP_UT_ON_CPU = False
+else:
+    test_device = ("cpu",)
+    torch_config_string = torch.__config__.show()
+    # training and some corner cases are not supported on CPU and will be skipped
+    SKIP_UT_ON_CPU = True
+    LONG_COMPILATION_ON_CPU = False
+    if "CLANG" in torch_config_string.upper():
+        # if the compiler is clang, skip UT for CPU due to long compilation time found in CI
+        # TODO: check reason of long compile time
+        LONG_COMPILATION_ON_CPU = True
+
+    test_dtypes = (
         [torch.float32, torch.bfloat16, torch.float16]
         if torch.backends.mkldnn.is_available()
        and torch.ops.mkldnn._is_mkldnn_bf16_supported()
        else [torch.float32]
-    ),
-    dtypes_fast=[torch.float32],
-)
-
-torch_config_string = torch.__config__.show()
-LONG_COMPILATION_ON_CPU = False
-
-if "CLANG" in torch_config_string.upper():
-    # if the compiler is clang, skip UT for CPU due to long compilation time found in CI
-    # TODO: check reason of long compile time
-    LONG_COMPILATION_ON_CPU = True
+    )
+    test_dtypes_fast = [torch.float32]
 
 
 # --------- Useful score mod functions for testing ---------
@@ -252,9 +212,9 @@ def _squared(score, b, h, m, n):
     return score * score
 
 
-def _head_offset(dtype: torch.dtype, device: str):
+def _head_offset(dtype: torch.dtype):
     """Captured Buffer"""
-    head_offset = torch.rand(H, device=device, dtype=dtype)
+    head_offset = torch.rand(H, device="cuda", dtype=dtype)
 
     def score_mod(score, b, h, m, n):
         return score * head_offset[h]
@@ -320,9 +280,9 @@ captured_buffers_map = {
     "_head_offset": _head_offset,
 }
 
-B = 2
-H = 4
-S = 256
+B = 4
+H = 8
+S = 2048
 D = 64
 
 test_Hq_Hkv = [
@@ -362,7 +322,7 @@ def query_key_value_clones(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    dtype: Optional[torch.dtype] = None,
+    dtype: torch.dtype = None,
 ):
     """Clones the query, key, and value tensors and moves them to the specified dtype."""
     if dtype is None:
@@ -386,10 +346,13 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor):
 class TestFlexAttention(InductorTestCase):
     def setUp(self):
         super().setUp()
-        skipCPUIf(
-            LONG_COMPILATION_ON_CPU,
-            "skip UT for CPU due to long compilation time found in CI",
-        )
+        self.test_inference_only = False
+        if test_device[0] == "cpu":
+            if LONG_COMPILATION_ON_CPU:
+                self.skipTest(
+                    "skip UT for CPU due to long compilation time found in CI"
+                )
+            self.test_inference_only = True
 
     def _check_equal(
         self,
@@ -475,8 +438,7 @@ class TestFlexAttention(InductorTestCase):
     def run_test(
         self,
         score_mod: _score_mod_signature,
-        dtype: torch.dtype,
-        device: str,
+        dtype: torch.dtype = torch.float16,
         Q_B: int = B,
         Q_H: int = H,
         Q_S: int = S,
@@ -486,8 +448,8 @@
         KV_S: Optional[int] = None,
         V_D: Optional[int] = None,
         block_mask: Optional[BlockMask] = None,
+        device="cuda",
     ):
-        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
         if KV_B is None:
             KV_B = Q_B
         if KV_H is None:
@@ -500,24 +462,23 @@
         if device == "cpu" and dtype is torch.float16:
             dtype = torch.float32
 
-        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
         q = torch.randn(
             (Q_B, Q_H, Q_S, Q_D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         k = torch.randn(
             (KV_B, KV_H, KV_S, Q_D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         v = torch.randn(
             (KV_B, KV_H, KV_S, V_D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
        )
         if block_mask is None:
             block_mask = create_block_mask(
@@ -533,12 +494,7 @@
         golden_out = sdpa_partial(q_gold, k_gold, v_gold)
         ref_out = sdpa_partial(q_ref, k_ref, v_ref)
         compiled_out = compiled_sdpa(q, k, v)
-
-        assert isinstance(golden_out, torch.Tensor)
-        assert isinstance(ref_out, torch.Tensor)
-        assert isinstance(compiled_out, torch.Tensor)
-
-        if not requires_grad:
+        if self.test_inference_only:
             self._check_out(
                 golden_out,
                 ref_out,
@@ -576,9 +532,9 @@
         k: Tensor,
         v: Tensor,
         block_mask,
-        dtype: torch.dtype,
-        device: str,
+        dtype: torch.dtype = torch.float16,
         page_size: int = 128,
+        device="cuda",
     ) -> tuple[Tensor, Tensor, BlockMask, _score_mod_signature]:
         assert block_mask is not None, "Must provide block_mask"
         Q_B, Q_H, Q_S, _ = q.shape
@@ -658,9 +614,9 @@
         q: Tensor,
         k: Tensor,
         v: Tensor,
-        dtype: torch.dtype,
-        device: str,
+        dtype: torch.dtype = torch.float16,
         block_mask: Optional[BlockMask] = None,
+        device="cuda",
     ) -> tuple[Tensor, Tensor]:
         B, Q_H, Q_S, KV_H, KV_S = (
             q.shape[0],
             q.shape[1],
             q.shape[2],
             k.shape[1],
             k.shape[2],
         )
@@ -679,16 +635,17 @@
             converted_block_mask,
             converted_score_mod,
         ) = self.preprocess_paged_attention(
-            score_mod, q, k, v, block_mask, dtype, device, block_mask.BLOCK_SIZE[1]
+            score_mod, q, k, v, block_mask, dtype, block_mask.BLOCK_SIZE[1], device
         )
 
         compiled_sdpa = torch.compile(flex_attention)
 
         # compute
         return_lse = True
-        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
-        if requires_grad:
-            compiled_out, compiled_lse = compiled_sdpa(
+        if self.test_inference_only:
+            return_lse = False
+            compiled_lse = None
+            compiled_out = compiled_sdpa(
                 q,
                 k_cache,
                 v_cache,
@@ -697,10 +654,9 @@
                 return_lse=return_lse,
                 block_mask=converted_block_mask,
                 score_mod=converted_score_mod,
                 enable_gqa=(not Q_H == KV_H),
             )
+        else:
-            return_lse = False
-            compiled_lse = None
-            compiled_out = compiled_sdpa(
+            compiled_out, compiled_lse = compiled_sdpa(
                 q,
                 k_cache,
                 v_cache,
@@ -713,9 +669,8 @@
 
     def run_test_with_paged_attention(
         self,
-        score_mod: Optional[Callable],
-        dtype: torch.dtype,
-        device,
+        score_mod: Optional[Callable] = _identity,
+        dtype: torch.dtype = torch.float16,
         Q_B: int = B,
         Q_H: int = H,
         Q_S: int = S,
@@ -725,6 +680,7 @@
         KV_S: int = S,
         V_D: int = D,
         block_mask: Optional[BlockMask] = None,
+        device="cuda",
     ):
         assert Q_H % KV_H == 0
         if device == "cpu" and dtype is torch.float16:
             dtype = torch.float32
@@ -758,7 +714,7 @@
         ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
 
         compiled_out, compiled_lse = self.run_paged_attention(
-            score_mod, q, k, v, dtype, device, block_mask
+            score_mod, q, k, v, dtype, block_mask, device
         )
         self._check_out(
             golden_out,
@@ -766,8 +722,8 @@
             compiled_out,
             is_paged_attention=True,
         )
-        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
-        if requires_grad:
+
+        if not self.test_inference_only:
             self._check_out(
                 golden_lse,
                 ref_lse,
@@ -778,8 +734,7 @@
     def run_test_with_call(
         self,
         sdpa_call: Callable,
-        dtype: torch.dtype,
-        device: str,
+        dtype: torch.dtype = torch.float16,
         Q_B: int = B,
         Q_H: int = H,
         Q_S: int = S,
@@ -788,29 +743,27 @@
         KV_H: int = H,
         KV_S: int = S,
         V_D: int = D,
+        device="cuda",
     ):
         if device == "cpu" and dtype is torch.float16:
             dtype = torch.float32
-
-        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
-
         q = torch.randn(
             (Q_B, Q_H, Q_S, Q_D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         k = torch.randn(
             (KV_B, KV_H, KV_S, Q_D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         v = torch.randn(
             (KV_B, KV_H, KV_S, V_D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
         q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
@@ -818,7 +771,7 @@
         golden_out = sdpa_call(q_gold, k_gold, v_gold)
         ref_out = sdpa_call(q_ref, k_ref, v_ref)
         compiled_out = compiled_sdpa(q, k, v)
-        if not requires_grad:
+        if self.test_inference_only:
             self._check_out(
                 golden_out,
                 ref_out,
@@ -852,12 +805,12 @@
     def run_dynamic_test(
         self,
         score_mask_mod: tuple[Callable, Callable],
-        dtype: torch.dtype,
-        device,
+        dtype: torch.dtype = torch.float16,
         B: int = B,
         H: int = H,
         S: int = S,
         D: int = D,
+        device="cuda",
     ):
         if device == "cpu" and dtype is torch.float16:
             dtype = torch.float32
@@ -868,32 +821,30 @@
         block_mask1 = create_block_mask(mask_mod, 1, 1, S, S, device=device)
         sdpa_partial1 = create_attention(score_mod, block_mask=block_mask1)
-        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
-
         q1 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         k1 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         v1 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
        )
         q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1)
         q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64)
         ref_out1 = sdpa_partial1(q1_ref, k1_ref, v1_ref)
         golden_out1 = sdpa_partial1(q1_gold, k1_gold, v1_gold)
-        if requires_grad:
+        if not self.test_inference_only:
             backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device=device)
             golden_out1.backward(backward_grad1.to(torch.float64))
             ref_out1.backward(backward_grad1)
@@ -908,26 +859,26 @@
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         k2 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         v2 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2)
         q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64)
         ref_out2 = sdpa_partial2(q2_ref, k2_ref, v2_ref)
         golden_out2 = sdpa_partial2(q2_gold, k2_gold, v2_gold)
-        if requires_grad:
+        if not self.test_inference_only:
             backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device=device)
             golden_out2.backward(backward_grad2.to(torch.float64))
             ref_out2.backward(backward_grad2)
@@ -941,26 +892,26 @@
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         k3 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         v3 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         q3_ref, k3_ref, v3_ref = query_key_value_clones(q3, k3, v3)
         q3_gold, k3_gold, v3_gold = query_key_value_clones(q3, k3, v3, torch.float64)
         ref_out3 = sdpa_partial3(q3_ref, k3_ref, v3_ref)
         golden_out3 = sdpa_partial3(q3_gold, k3_gold, v3_gold)
-        if requires_grad:
+        if not self.test_inference_only:
             backward_grad3 = torch.randn((B, H, S, D), dtype=dtype, device=device)
             golden_out3.backward(backward_grad3.to(torch.float64))
             ref_out3.backward(backward_grad3)
@@ -973,7 +924,7 @@
         compiled_sdpa1 = torch.compile(sdpa_partial1, backend=backend, dynamic=True)
         compiled_out1 = compiled_sdpa1(q1, k1, v1)
 
-        if requires_grad:
+        if not self.test_inference_only:
             compiled_out1.backward(backward_grad1)
 
         self._check_out_and_grad(
@@ -998,7 +949,7 @@
         compiled_sdpa2 = torch.compile(sdpa_partial2, backend=backend, dynamic=True)
         compiled_out2 = compiled_sdpa2(q2, k2, v2)
 
-        if requires_grad:
+        if not self.test_inference_only:
             compiled_out2.backward(backward_grad2)
 
         self._check_out_and_grad(
@@ -1023,7 +974,7 @@
         compiled_sdpa3 = torch.compile(sdpa_partial3, backend=backend, dynamic=True)
         compiled_out3 = compiled_sdpa3(q3, k3, v3)
 
-        if requires_grad:
+        if not self.test_inference_only:
             compiled_out3.backward(backward_grad3)
 
         self._check_out_and_grad(
@@ -1047,12 +998,12 @@
     def run_automatic_dynamic_test(
         self,
         score_mod: Callable,
-        dtype: torch.dtype,
-        device: str,
+        dtype: torch.dtype = torch.float16,
         B: int = B,
         H: int = H,
         S: int = S,
         D: int = D,
+        device="cuda",
     ):
         if device == "cpu" and dtype is torch.float16:
             dtype = torch.float32
@@ -1060,25 +1011,23 @@
         block_mask1 = create_block_mask(noop_mask, 1, 1, S, S, device=device)
         sdpa_partial1 = create_attention(score_mod, block_mask=block_mask1)
         # The first eager batch, shape (B, H, S, D)
-        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
-
         q1 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         k1 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
        )
         v1 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         golden_out1 = sdpa_partial1(
             q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
         )
@@ -1094,19 +1043,19 @@
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         k2 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         v2 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         golden_out2 = sdpa_partial2(
             q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
         )
@@ -1122,19 +1071,19 @@
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         k3 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         v3 = torch.randn(
             (B, H, S, D),
             dtype=dtype,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         golden_out3 = sdpa_partial3(
             q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64)
         )
@@ -1157,40 +1106,32 @@
         # The first batch.
         backend = torch._dynamo.testing.CompileCounterWithBackend("inductor")
-        compiled_out1 = torch.compile(sdpa_partial1, backend=backend, fullgraph=True)(
-            q1, k1, v1
-        )
+        compiled_out1 = torch.compile(sdpa_partial1, backend=backend)(q1, k1, v1)
         self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor)
         self.assertEqual(backend.frame_count, 1)
 
         # The second batch (automatic dynamic).
-        compiled_out2 = torch.compile(sdpa_partial2, backend=backend, fullgraph=True)(
-            q2, k2, v2
-        )
+        compiled_out2 = torch.compile(sdpa_partial2, backend=backend)(q2, k2, v2)
         self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor)
         self.assertEqual(backend.frame_count, 2)
 
         # The third batch (no re-compilation).
-        compiled_out3 = torch.compile(sdpa_partial3, backend=backend, fullgraph=True)(
-            q3, k3, v3
-        )
+        compiled_out3 = torch.compile(sdpa_partial3, backend=backend)(q3, k3, v3)
         self._check_equal(golden_out3, ref_out3, compiled_out3, fudge_factor)
         self.assertEqual(backend.frame_count, 2)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes)
+    @common_utils.parametrize("dtype", test_dtypes)
     @common_utils.parametrize("score_mod", test_score_mods)
-    def test_builtin_score_mods(self, device, dtype, score_mod: Callable):
+    def test_builtin_score_mods(self, device, dtype: torch.dtype, score_mod: Callable):
         self.run_test(score_mod, dtype, device=device)
         self.run_test_with_paged_attention(score_mod, dtype, device=device)
 
     @running_on_a100_only
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     @common_utils.parametrize("score_mod", test_score_mods)
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
     def test_builtin_score_mods_seqlen_lt_default_sparse_block_size(
-        self, device, dtype, score_mod: Callable
+        self, device, dtype: torch.dtype, score_mod: Callable
     ):
         # _DEFAULT_SPARSE_BLOCK_SIZE is 128
         attention = functools.partial(
@@ -1198,11 +1139,12 @@
             flex_attention,
             score_mod=score_mod,
             kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
         )
-        self.run_test_with_call(attention, dtype, device, B, H, 64, D, B, H, 64, D)
+        self.run_test_with_call(
+            attention, dtype, B, H, 64, D, B, H, 64, D, device=device
+        )
 
     @running_on_a100_only
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     @common_utils.parametrize("score_mod", test_score_mods)
     def test_builtin_score_mods_seqlen_lt_custom_sparse_block_size(
         self, device, dtype: torch.dtype, score_mod: Callable
@@ -1220,40 +1162,27 @@
             kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
         )
         self.run_test_with_call(
-            attention,
-            dtype,
-            device,
-            B,
-            H,
-            64,
-            D,
-            B,
-            H,
-            64,
-            D,
+            attention, dtype, B, H, 64, D, B, H, 64, D, device=device
         )
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     @common_utils.parametrize("score_mask_mod", test_score_mask_mod_map.items())
     def test_builtin_score_mods_dynamic(
         self, device, dtype: torch.dtype, score_mask_mod: tuple[Callable, Callable]
     ):
-        self.run_dynamic_test(score_mask_mod, dtype, S=1024, device=device)
+        self.run_dynamic_test(score_mask_mod, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     @common_utils.parametrize("score_mod", test_score_mods)
     def test_builtin_score_mods_automatic_dynamic(
         self, device, dtype: torch.dtype, score_mod: Callable
     ):
-        self.run_automatic_dynamic_test(score_mod, dtype, S=1024, device=device)
+        self.run_automatic_dynamic_test(score_mod, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     @common_utils.parametrize("score_mod", test_score_mods)
     def test_builtin_score_mods_different_seqlen(
         self, device, dtype: torch.dtype, score_mod: Callable
     ):
@@ -1261,7 +1190,6 @@
         inputs = (
             score_mod,
             dtype,
-            device,
             B,
             H,
             S // 2,  # Seqlen of Q is different from seqlen of K/V
             D,
@@ -1271,12 +1199,11 @@
             S,
             D,
         )
-        self.run_test(*inputs)
-        self.run_test_with_paged_attention(*inputs)
+        self.run_test(*inputs, device=device)
+        self.run_test_with_paged_attention(*inputs, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes)
+    @common_utils.parametrize("dtype", test_dtypes)
     @common_utils.parametrize("score_mod", test_score_mods)
     @common_utils.parametrize("BLOCK_SIZE", test_block_size)
     def test_builtin_score_mods_different_block_size(
@@ -1295,8 +1222,7 @@
     )
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     @common_utils.parametrize("batch_dims", test_Bq_Bkv)
     @common_utils.parametrize("head_dims", test_Hq_Hkv)
     @common_utils.parametrize("score_mod", test_score_mods)
@@ -1317,11 +1243,10 @@
         block_mask = create_block_mask(noop_mask, Bq, 1, S, S, device=device)
 
         self.run_test(
-            score_mod, dtype, device, Bq, Hq, S, D, Bkv, Hkv, S, D, block_mask
+            score_mod, dtype, Bq, Hq, S, D, Bkv, Hkv, S, D, block_mask, device=device
         )
 
     @supported_platform
-    @skip_on_cpu
     def test_small_block_mask(self, device):
         compiled_create_block_mask = torch.compile(create_block_mask)
 
@@ -1365,8 +1290,7 @@
         create_block_mask_from_seqlens(seqlen, seqlen)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     @common_utils.parametrize("batch_dims", test_Bq_Bkv)
     @common_utils.parametrize("head_dims", test_Hq_Hkv)
     @common_utils.parametrize("score_mod", test_score_mods)
@@ -1392,17 +1316,17 @@
             flex_attention, block_mask=block_mask, enable_gqa=(not Hq == Hkv)
         )
 
-        self.run_test_with_call(attention, dtype, device, Bq, Hq, S, D, Bkv, Hkv, S, D)
+        self.run_test_with_call(
+            attention, dtype, Bq, Hq, S, D, Bkv, Hkv, S, D, device=device
+        )
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     @common_utils.parametrize("score_mod", test_score_mods)
     def test_GQA(self, device, dtype: torch.dtype, score_mod: Callable):
         inputs = (
             score_mod,
             dtype,
-            device,
             B,
             H * 4,  # Hq = 4*Hkv.
             S // 8,
             D,
             B,
             H,
             S,
             D,
         )
-        self.run_test(*inputs)
-        self.run_test_with_paged_attention(*inputs)
+        self.run_test(*inputs, device=device)
+        self.run_test_with_paged_attention(*inputs, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     @common_utils.parametrize(
         "q_s", test_strides[:2]
     )  # TODO: fix layout for query braodcasting
@@ -1444,15 +1367,13 @@
         v_shape = (B, H, S, D)
         do_shape = (B, H, S // 2, D)
 
-        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
-
         def coerce_to_strides(val, shape, strides):
             strides, offset = strides
             val_max = [x * (y - 1) for x, y in zip(strides, shape)]
             assert sum(val_max) + offset < B * H * S * D * 2
             assert strides[-1] == 1
             return torch.as_strided(val, shape, strides, offset).requires_grad_(
-                requires_grad
+                not self.test_inference_only
             )
 
         q = coerce_to_strides(q1, q_shape, q_s)
@@ -1463,7 +1384,7 @@
         block_mask = _create_empty_block_mask(q, k)
         score_mod = _generate_alibi_bias(8)
         sdpa_partial = create_attention(score_mod=score_mod, block_mask=block_mask)
-        compiled_sdpa = torch.compile(sdpa_partial, fullgraph=True)
+        compiled_sdpa = torch.compile(sdpa_partial)
         ref_out = sdpa_partial(q, k, v)
         compiled_out = compiled_sdpa(q, k, v)
 
@@ -1471,7 +1392,7 @@
         torch.testing.assert_close(
             ref_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol
         )
-        if requires_grad:
+        if not self.test_inference_only:
             ref_out.backward(do)
             ref_grads = [q.grad, k.grad, v.grad]
             q.grad = None
@@ -1559,8 +1480,7 @@
         self.run_test_with_paged_attention(index_weird2, torch.float16, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes)
+    @common_utils.parametrize("dtype", test_dtypes)
     def test_skip_odd_keys(self, device, dtype: torch.dtype):
         def score_mod(score, b, h, q, kv):
             return torch.where(kv % 2 == 0, score, float("-inf"))
@@ -1569,8 +1489,7 @@
         self.run_test(score_mod, dtype, device=device)
         self.run_test_with_paged_attention(score_mod, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes)
+    @common_utils.parametrize("dtype", test_dtypes)
     def test_function_composition(self, device, dtype: torch.dtype):
         def score_mod_1(score, b, h, m, n):
             return score + (m - n)
@@ -1585,8 +1504,7 @@
         self.run_test_with_paged_attention(composed_score_mod, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes)
+    @common_utils.parametrize("dtype", test_dtypes)
     def test_captured_buffers_all_dims(self, device, dtype: torch.dtype):
         head_scale = torch.randn(H, device=device)
         batch_scale = torch.randn(B, device=device)
@@ -1602,8 +1520,7 @@
         self.run_test_with_paged_attention(all_bias, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_seq_masking(self, device, dtype):
         seq_idx = torch.zeros(S, device=device, dtype=torch.bool)
         seq_idx[S // 2 :] = 1
@@ -1615,8 +1532,7 @@ class TestFlexAttention(InductorTestCase):
         self.run_test_with_paged_attention(seq_mask_mod, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_load_from_bias_seq_only(self, device, dtype):
         bias = torch.randn(S, S, device=device, dtype=dtype)
 
@@ -1627,8 +1543,7 @@
         self.run_test_with_paged_attention(bias_mod, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_load_from_bias_seq_batch(self, device, dtype):
         bias = torch.randn(B, S, S, device=device, dtype=dtype)
 
@@ -1639,9 +1554,10 @@
         self.run_test_with_paged_attention(bias_mod, dtype, device=device)
 
     @supported_platform
-    @skip_on_cpu
-    def test_load_from_view_buffer(self, device):
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_load_from_view_buffer(self):
         dtype = torch.float16
+        device = "cuda"
         W = 8
 
         class SimpleAttention(torch.nn.Module):
@@ -1686,8 +1602,7 @@
         out.sum().backward()
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_load_from_bias_head_seq_batch(self, device, dtype):
         bias = torch.randn(B, H, S, S, device=device, dtype=dtype)
 
@@ -1698,8 +1613,7 @@
         self.run_test_with_paged_attention(bias_mod, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_load_rel_bias(self, device, dtype):
         rel_bias = torch.randn(2 * S, device=device, dtype=dtype)
 
@@ -1710,8 +1624,7 @@
         self.run_test_with_paged_attention(bias_mod, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_dependent_causal_bidirectional(self, device, dtype):
         num_bidirectional = torch.randint(0, S, (B,), device=device, dtype=torch.int32)
 
@@ -1731,8 +1644,7 @@
         self.run_test_with_paged_attention(bias_mod, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_natten_2d(self, device, dtype):
         H = 32
         W = S // H
@@ -1756,7 +1668,8 @@
         self.run_test_with_paged_attention(natten_mask, dtype, device=device)
 
     @supported_platform
-    def test_subgraph_respect_decompostion(self, device):
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_subgraph_respect_decompostion(self, device, dtype):
         from torch._decomp import core_aten_decompositions
         from torch.fx.experimental.proxy_tensor import make_fx
 
@@ -1799,8 +1712,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         )
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_silu_on_score(self, device, dtype):
         def silu_score(score, b, h, q, kv):
             return torch.nn.functional.silu(score)
 
@@ -1809,8 +1721,7 @@
         self.run_test(silu_score, dtype, device=device)
         self.run_test_with_paged_attention(silu_score, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_padded_dense_causal(self, device, dtype):
         seq_len = torch.arange(B, device=device, dtype=torch.int32) + 1
 
@@ -1827,8 +1738,7 @@
         self.run_test(causal_njt, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_captured_scale(self, device, dtype):
         scale = torch.ones((), device=device, dtype=torch.int32)
 
@@ -1839,8 +1749,7 @@
         self.run_test_with_paged_attention(score_mod_scale, dtype, device=device)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_recompile_changed_score_mod(self, device, dtype):
         scale = torch.ones((), device=device, dtype=torch.int32)
         ADD = True
@@ -1860,8 +1769,7 @@
 
     @supported_platform
     @expectedFailure  # If we capture a tensor then we can perform a reduction on it, and that shouldn't be allowed
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
    def test_captured_reduction(self, device, dtype):
         scale = torch.randn((B, 8), device=device)
 
@@ -1898,16 +1806,18 @@
         torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol)
 
     @supported_platform
-    @skip_on_cpu
-    def test_multiple_mask_calls(self, device):
-        make_tensor = functools.partial(
-            torch.randn,
-            (1, 4, 512, 64),
-            dtype=torch.float32,
-            device=device,
-            requires_grad=True,
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_multiple_mask_calls(self):
+        # Create inputs
+        query = torch.randn(
+            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        key = torch.randn(
+            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        value = torch.randn(
+            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
         )
-        query, key, value = make_tensor(), make_tensor(), make_tensor()
 
         window_size = 32
 
@@ -1966,7 +1876,7 @@
             return flex_attention(q3, k3, v3, score_mod=scoremod_1)
 
         out = f(query, *keys, *values)
-        out2 = torch.compile(f, fullgraph=True)(query, *keys, *values)
        self.assertTrue((out - out2).abs().mean() < 1e-2)
+        out2 = torch.compile(f)(query, *keys, *values)

 
    @supported_platform
@@ -2040,7 +1950,7 @@
             block_mask=converted_block_mask2,
         )
 
-        compiled_out = torch.compile(paged_f, fullgraph=True)(
+        compiled_out = torch.compile(paged_f)(
             query, k_cache1, k_cache2, v_cache1, v_cache2
         )
         tolerance = Tolerances(atol=2e-1, rtol=2e-1)
@@ -2142,7 +2052,7 @@
             block_mask=converted_block_mask3,
         )
 
-        compiled_out = torch.compile(paged_f, fullgraph=True)(
+        compiled_out = torch.compile(paged_f)(
             query, k_cache1, k_cache2, k_cache3, v_cache1, v_cache2, v_cache3
         )
         tolerance = Tolerances(atol=2e-1, rtol=2e-1)
@@ -2151,10 +2061,10 @@
         )
 
     @supported_platform
-    @skip_on_cpu
-    def test_inputs_are_realized(self, device):
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_inputs_are_realized(self):
         def f(q, k, v):
-            x = torch.randn(1024, device=device)
+            x = torch.randn(1024, device="cuda")
             x = x * 2
 
             def func(qk, b, h, q, kv):
@@ -2163,7 +2073,7 @@
             return flex_attention(q.sin(), k, v, score_mod=func).cos()
 
         q, k, v = (
-            torch.randn(1, 8, 1024, 64, device=device, requires_grad=True)
+            torch.randn(1, 8, 1024, 64, device="cuda", requires_grad=True)
             for _ in range(3)
         )
         ref = f(q, k, v)
@@ -2177,12 +2087,11 @@
         self.assertTrue((ref - out).abs().mean() < 1e-2)
 
     @supported_platform
-    @skip_on_cpu
     def test_make_block_mask(self, device):
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx
 
-        block_mask_a = torch.compile(create_block_mask, fullgraph=True)(
+        block_mask_a = torch.compile(create_block_mask)(
             causal_mask, 1, 1, 512, 512, device=device
         )
         block_mask_b = create_block_mask(causal_mask, 1, 1, 512, 512, device=device)
@@ -2201,53 +2110,33 @@
         def sliding_window(b, h, q, kv):
             return (q - kv) <= 512
 
-        local_s = 2048
         block_mask = create_block_mask(
-            and_masks(causal_mask, sliding_window),
-            1,
-            1,
-            local_s,
-            local_s,
-            device=device,
+            and_masks(causal_mask, sliding_window), 1, 1, S, S, device=device
         )
         self.assertExpectedInline(block_mask.kv_num_blocks.sum().item(), """28""")
 
         attention = functools.partial(flex_attention, block_mask=block_mask)
-        self.run_test_with_call(
-            attention, Q_S=local_s, KV_S=local_s, dtype=torch.float16, device=device
-        )
+        self.run_test_with_call(attention, device=device)
 
         block_mask = create_block_mask(
-            and_masks(causal_mask, neg_causal_mask),
-            1,
-            1,
-            local_s,
-            local_s,
-            device=device,
+            and_masks(causal_mask, neg_causal_mask), 1, 1, S, S, device=device
         )
         self.assertEqual(block_mask.kv_num_blocks.sum(), 0)
 
         block_mask1 = create_block_mask(
-            or_masks(causal_mask, neg_causal_mask),
-            1,
-            1,
-            local_s,
-            local_s,
-            device=device,
-        )
-        block_mask2 = create_block_mask(
-            noop_mask, 1, 1, local_s, local_s, device=device
+            or_masks(causal_mask, neg_causal_mask), 1, 1, S, S, device=device
         )
+        block_mask2 = create_block_mask(noop_mask, 1, 1, S, S, device=device)
         self.assertEqual(block_mask1.sparsity(), block_mask2.sparsity())
 
     @supported_platform
-    @skip_on_cpu
-    def test_epilogue_fused(self, device):
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_epilogue_fused(self):
         @torch.compile
         def f(q, k, v):
             out = flex_attention(q, k, v)
             return out.cos()
 
-        q, k, v = (torch.randn(1, 8, 1024, 64, device=device) for _ in range(3))
+        q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda") for _ in range(3))
         metrics.reset()
         _, code = run_and_get_code(f, q, k, v)
         fc = FileCheck()
@@ -2262,8 +2151,7 @@
         self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes)
+    @common_utils.parametrize("dtype", test_dtypes)
     def test_njt_causal(self, device, dtype):
         offsets = torch.tensor(
             [0, 1024, 1024 + 512, S], device=device, dtype=torch.int32
         )
@@ -2301,18 +2189,16 @@
         def score_mod(score, b, h, m, n):
             return score * 2
 
-        self.run_test(score_mod, dtype=torch.float16, device=device)
-        self.run_test_with_paged_attention(
-            score_mod, dtype=torch.float16, device=device
-        )
+        self.run_test(score_mod, device=device)
+        self.run_test_with_paged_attention(score_mod, device=device)
 
     @supported_platform
     @skip("TODO: Figure out why this is erroring")
     @patch.object(torch._inductor.config, "max_autotune", True)
-    def test_max_autotune_with_captured(self, device):
-        head_scale = torch.randn(H, device=device)
-        batch_scale = torch.randn(B, device=device)
-        tok_scale = torch.randn(S, device=device)
+    def test_max_autotune_with_captured(self):
+        head_scale = torch.randn(H, device="cuda")
+        batch_scale = torch.randn(B, device="cuda")
+        tok_scale = torch.randn(S, device="cuda")
 
         def bias_mod(score, batch, head, token_q, token_kv):
             score = score + tok_scale[token_q]
@@ -2320,23 +2206,22 @@
             score = score + batch_scale[batch]
             score = score + head_scale[head]
             return score
 
-        self.run_test(bias_mod, dtype=torch.float32, device=device)
+        self.run_test(bias_mod)
 
     @supported_platform
     @common_utils.parametrize("score_mod", test_score_mods)
-    @dtypes(*device_configs["cpu"].dtypes)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes)
+    @common_utils.parametrize("dtype", test_dtypes)
     @common_utils.parametrize("head_dims", [(D, D // 2), (D // 2, D)])
     def test_non_equal_head_dims(self, device, dtype, score_mod, head_dims):
         qk_d, v_d = head_dims
-        self.run_test(score_mod, dtype, device, B, H, S, qk_d, B, H, S, V_D=v_d)
+        self.run_test(score_mod, dtype, B, H, S, qk_d, B, H, S, V_D=v_d, device=device)
         self.run_test_with_paged_attention(
-            score_mod, dtype, device, B, H, S, qk_d, B, H, S, V_D=v_d
+            score_mod, dtype, B, H, S, qk_d, B, H, S, V_D=v_d, device=device
        )
 
     @supported_platform
-    @skip_on_cpu
-    def test_autograd_function_in_score_mod(self, device):
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_autograd_function_in_score_mod(self):
         class ApplyMask(torch.autograd.Function):
             generate_vmap_rule = True
 
@@ -2359,7 +2244,7 @@
         func = torch.compile(flex_attention, fullgraph=True)
 
         q, k, v = (
-            torch.randn(1, 8, 1024, 64, device=device, requires_grad=True)
+            torch.randn(1, 8, 1024, 64, device="cuda", requires_grad=True)
             for _ in range(3)
         )
 
@@ -2378,7 +2263,7 @@
         block_mask = create_block_mask(mask_mod, 1, 1, S, S, device=device)
         attention = functools.partial(flex_attention, block_mask=block_mask)
 
-        self.run_test_with_call(attention, dtype=torch.float16, device=device)
+        self.run_test_with_call(attention, device=device)
 
     @supported_platform
     def test_causal_block_paged_attention(self, device):
@@ -2387,10 +2272,7 @@
         block_mask = create_block_mask(mask_mod, B, 1, S, S, device=device)
         self.run_test_with_paged_attention(
-            score_mod=_identity,
-            dtype=torch.float16,
-            device=device,
-            block_mask=block_mask,
+            score_mod=_identity, block_mask=block_mask, device=device
         )
 
     @supported_platform
@@ -2411,16 +2293,15 @@
         block_mask = create_block_mask(
             causal, B=4, H=None, Q_LEN=S, KV_LEN=S, device=device
         )
-        torch.compile(flex_attention, fullgraph=True)(
-            q, k, v, score_mod, block_mask=block_mask
-        )
+        torch.compile(flex_attention)(q, k, v, score_mod, block_mask=block_mask)
 
     @supported_platform
     @common_utils.parametrize("head_dim", [17, 24, 94, 121])
-    @dtypes(*device_configs["cpu"].dtypes_fast)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_non_pow_2_headdim(self, device, dtype, head_dim):
-        self.run_test(_rel_bias, dtype, device, B, H, S, head_dim, B, H, S, head_dim)
+        self.run_test(
+            _rel_bias, dtype, B, H, S, head_dim, B, H, S, head_dim, device=device
+        )
 
     @supported_platform
     def test_GQA_causal_mask(self, device):
@@ -2435,7 +2316,6 @@
         self.run_test_with_call(
             attention,
             torch.float16,
-            device,
             B,
             H * 4,  # Hq = 4*Hkv.
             S // 8,
@@ -2444,17 +2324,16 @@
             H,
             S // 8,
             D,
+            device=device,
         )
 
         self.run_test_with_paged_attention(
-            _identity,
-            dtype=torch.float16,
-            device=device,
             Q_H=H * 4,
             Q_S=S // 8,
             KV_H=H,
             KV_S=S // 8,
             block_mask=block_mask,
+            device=device,
         )
 
     @supported_platform
@@ -2480,16 +2359,15 @@
         self.assertEqual(auto_mask.to_dense(), manual_mask.to_dense())
 
     @supported_platform
-    @skip_on_cpu
-    @dtypes(*device_configs["cpu"].dtypes)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes)
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    @common_utils.parametrize("dtype", test_dtypes)
     @common_utils.parametrize("score_mod", [_identity, _causal])
-    def test_logsumexp_correctness(self, device, dtype, score_mod):
+    def test_logsumexp_correctness(self, dtype, score_mod):
         make_tensor = functools.partial(
             torch.randn,
             (B, H, S, D),
             dtype=dtype,
-            device=device,
+            device="cuda",
             requires_grad=True,
         )
         q, k, v = make_tensor(), make_tensor(), make_tensor()
@@ -2528,13 +2406,13 @@
         )
 
     @supported_platform
-    @skip_on_cpu
-    def test_logsumexp_only_return(self, device):
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_logsumexp_only_return(self):
         make_tensor = functools.partial(
             torch.randn,
             (B, H, S, D),
             dtype=torch.float32,
-            device=device,
+            device="cuda",
             requires_grad=True,
         )
         q, k, v = make_tensor(), make_tensor(), make_tensor()
@@ -2552,15 +2430,15 @@
         )
 
     @supported_platform
-    @skip_on_cpu
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
     @common_utils.parametrize(
         "score_mod", [_identity, _causal, _times_two, _squared, _trig, _trig2]
     )
-    def test_aot_eager_gradcheck(self, device, score_mod):
+    def test_aot_eager_gradcheck(self, score_mod):
         make_tensor = functools.partial(
             torch.randn,
             (2, 2, 11, 4),
-            device=device,
+            device="cuda",
             dtype=torch.float64,
             requires_grad=True,
         )
@@ -2575,7 +2453,7 @@
         )
 
     @supported_platform
-    @skip_on_cpu
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
     def test_eager_backward_strides(self):
         class Repro(torch.nn.Module):
             def __init__(self):
@@ -2600,16 +2478,16 @@
         model = Repro().cuda()
         x = torch.randn((1, 512, 256), device="cuda", requires_grad=True)
-        out = torch.compile(model, backend="aot_eager", fullgraph=True)(x)
+        out = torch.compile(model, backend="aot_eager")(x)
         out.backward(torch.ones_like(out))
 
     @supported_platform
-    @skip_on_cpu
-    def test_differentiable_logsumexp_gradcheck(self, device):
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_differentiable_logsumexp_gradcheck(self):
         make_tensor = functools.partial(
             torch.randn,
             (2, 2, 11, 4),
-            device=device,
+            device="cuda",
             dtype=torch.float64,
             requires_grad=True,
         )
@@ -2627,17 +2505,17 @@
         )
 
     @supported_platform
-    @skip_on_cpu
-    def test_differentiable_logsumexp_compiled(self, device):
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_differentiable_logsumexp_compiled(self):
         make_tensor = functools.partial(
             torch.randn,
             (2, 2, 128, 64),
-            device=device,
+            device="cuda",
             dtype=torch.float32,
             requires_grad=True,
         )
         q, k, v = make_tensor(), make_tensor(), make_tensor()
-        lse_mask = torch.randn(2, 2, 128, device=device)
+        lse_mask = torch.randn(2, 2, 128, device="cuda")
 
         out, lse = flex_attention(q, k, v, return_lse=True)
         (out.mean() + (lse * lse_mask).sum()).backward()
@@ -2646,9 +2524,7 @@
         k.grad = None
         v.grad = None
 
-        out2, lse2 = torch.compile(flex_attention, fullgraph=True)(
-            q, k, v, return_lse=True
-        )
+        out2, lse2 = torch.compile(flex_attention)(q, k, v, return_lse=True)
         (out2.mean() + (lse2 * lse_mask).sum()).backward()
         q_grad2, k_grad2, v_grad2 = q.grad, k.grad, v.grad
         tolerance = Tolerances(atol=1e-1, rtol=1e-1)
@@ -2667,22 +2543,22 @@
     # Use weird mask to test reusing block_mask does work well.
     @supported_platform
-    @skip_on_cpu
-    def _test_block_mask_reuse_with_weird_mask(self, device):
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def _test_block_mask_reuse_with_weird_mask(self):
         def mask(b, h, q, kv):
             return (kv < 256) | (kv >= 2048)
 
         make_tensor = functools.partial(
             torch.randn,
             (4, 4, 4096, 64),
-            device=device,
+            device="cuda",
             dtype=torch.float32,
             requires_grad=True,
         )
 
         block_mask = create_block_mask(mask, None, None, 4096, 4096)
         # Compile 1st version with q/k/v(seqlen=4096) and block_mask(seqlen=4096)
-        torch.compile(flex_attention, dynamic=True, fullgraph=True)(
+        torch.compile(flex_attention, dynamic=True)(
             make_tensor(), make_tensor(), make_tensor(), block_mask=block_mask
         )
 
@@ -2697,7 +2573,7 @@
 
         # Compile 2st version with q/k/v(seqlen=2048) and block_mask(seqlen=4096),
         # The graph includes the BlockMask._adjust part.
-        out = torch.compile(flex_attention, dynamic=True, fullgraph=True)(
+        out = torch.compile(flex_attention, dynamic=True)(
             q, k, v, block_mask=block_mask
         )
         out.sum().backward()
@@ -2708,7 +2584,7 @@
         block_mask2 = create_block_mask(mask, None, None, 2048, 2048)
 
         # Reuse the 1st version with q/k/v(seqlen=2048) and block_mask(seqlen=2048)
-        out2 = torch.compile(flex_attention, dynamic=True, fullgraph=True)(
+        out2 = torch.compile(flex_attention, dynamic=True)(
             q, k, v, block_mask=block_mask2
         )
         out2.sum().backward()
@@ -2727,12 +2603,12 @@
         )
 
     @supported_platform
-    @skip_on_cpu
-    def test_float32_matmul_precision(self, device):
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_float32_matmul_precision(self):
         make_tensor = functools.partial(
             torch.zeros,
             (2, 2, 128, 32),
-            device=device,
+            device="cuda",
             dtype=torch.float32,
             requires_grad=False,
         )
@@ -2759,23 +2635,23 @@
         torch.testing.assert_close(grads_eager, grads_compile)
 
     @supported_platform
-    @skip_on_cpu
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
     @common_utils.parametrize("score_mod_name", ["_head_offset"])
     @common_utils.parametrize("mode", ["eager", "aot_eager"])
     def test_captured_score_mod_aot_eager_gradcheck(
-        self, device, score_mod_name: str, mode: str
+        self, score_mod_name: str, mode: str
     ):
         make_tensor = functools.partial(
             torch.randn,
             (2, 2, 11, 4),
-            device=device,
+            device="cuda",
             dtype=torch.float64,
             requires_grad=True,
         )
         query, key, value = make_tensor(), make_tensor(), make_tensor()
         func = torch.compile(flex_attention, backend=mode, fullgraph=True)
-        score_mod = captured_buffers_map[score_mod_name](torch.float64, device)
+        score_mod = captured_buffers_map[score_mod_name](torch.float64)
 
         self.assertTrue(
             torch.autograd.gradcheck(
@@ -2784,10 +2660,8 @@
         )
 
     @supported_platform
-    @skip_on_cpu
     @common_utils.parametrize("mode", ["eager", "aot_eager"])
     def test_document_masking_edge_case(self, device, mode):
-        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
         document_masks = torch.full((2, 128), 0, dtype=torch.int32, device=device)
         document_masks[:, 64:] = 1
 
@@ -2800,18 +2674,18 @@
             (2, 1, 128, 4),
             device=device,
             dtype=torch.float64,
-            requires_grad=requires_grad,
+            requires_grad=True,
         )
         query, key, value = make_tensor(), make_tensor(), make_tensor()
         func = torch.compile(flex_attention, backend=mode, fullgraph=True)
 
         block_mask = create_block_mask(mask_mod, 2, 1, 128, 128, device=device)
         out = func(query, key, value, block_mask=block_mask)
-        if requires_grad:
+        if device != "cpu":
             out.sum().backward()
 
     @supported_platform
-    @skip_on_cpu
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
     def test_strided_backwards(self):
         shape = (1, 2, 4096, 64)
         Q = torch.randn(shape, requires_grad=True, device="cuda")
@@ -2850,13 +2724,14 @@
         dtype = torch.float32
 
         # Setup
-        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
         make_tensor = functools.partial(
             torch.randn,
             shape,
             device=device,
             dtype=dtype,
-            requires_grad=False if mode == "paged_attention" else requires_grad,
+            requires_grad=False
+            if mode == "paged_attention"
+            else not self.test_inference_only,
         )
 
         # Create and permute tensors
@@ -2886,7 +2761,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         )
 
     @supported_platform
-    @skip_on_cpu
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
     @common_utils.parametrize("mode", ["eager", "inductor"])
     @common_utils.parametrize(
         "permute_order",
@@ -2898,14 +2773,12 @@
         [
             [0, 1, 2, 3],
             [1, 0, 2, 3],
            [0, 2, 1, 3],
            [2, 0, 1, 3],
         ],
     )
     @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
-    def test_flex_attention_backward_stride_ordering(
-        self, device, mode, permute_order, shape
-    ):
+    def test_flex_attention_backward_stride_ordering(self, mode, permute_order, shape):
         from torch._inductor.ir import get_stride_order
 
         dtype = torch.float32
         make_tensor = functools.partial(
-            torch.randn, shape, device=device, dtype=dtype, requires_grad=False
+            torch.randn, shape, device="cuda", dtype=dtype, requires_grad=False
         )
 
         query, key, value = make_tensor(), make_tensor(), make_tensor()
@@ -2943,24 +2816,23 @@
     @common_utils.parametrize("compile", [True, False])
     def test_fully_masked_out_rows_0_check(self, device, compile: bool):
         # Ensure fully masked out rows won't cause NaNs.
-        requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
         query = torch.randn(
             (B, H, S, D),
             dtype=torch.float32,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         key = torch.randn(
             (B, H, S, D),
             dtype=torch.float32,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
         value = torch.randn(
             (B, H, S, D),
             dtype=torch.float32,
             device=device,
-            requires_grad=requires_grad,
+            requires_grad=not self.test_inference_only,
         )
 
         M = S // 2
@@ -2973,7 +2845,7 @@
         flex = (
             torch.compile(flex_attention, dynamic=False) if compile else flex_attention
         )
-        if requires_grad:
+        if not self.test_inference_only:
             out, lse = flex(query, key, value, block_mask=block_mask, return_lse=True)
             self.assertEqual(out[:, :, M:, :].sum(), 0)
             self.assertTrue((lse[:, :, M:] == -float("inf")).all())
@@ -2987,7 +2859,8 @@
             self.assertEqual(out[:, :, M:, :].sum(), 0)
 
     @supported_platform
-    def test_fully_masked_out_rows(self, device):
+    @common_utils.parametrize("compile", [True, False])
+    def test_fully_masked_out_rows(self, device, compile: bool):
         M = S // 2
 
         def mask_mod(b, h, q, kv):
@@ -2999,12 +2872,12 @@
             return score
 
         self.run_test(
-            noop_mod, torch.float32, device, B, H, S, D, B, H, S, D, block_mask
+            noop_mod, torch.float32, B, H, S, D, B, H, S, D, block_mask, device=device
         )
 
     @supported_platform
-    @skip_on_cpu
-    def test_kernel_options_argument_is_respected(self, device):
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Only test on cuda for this kernel option")
+    def test_kernel_options_argument_is_respected(self):
         make_tensor = functools.partial(
             torch.randn,
             (2, 2, 128, 64),
@@ -3016,14 +2889,80 @@
         )
 
         # Ensure we respect user's input kernel options.
         _, code = run_and_get_code(
-            torch.compile(flex_attention, fullgraph=True),
-            q,
-            k,
-            v,
-            kernel_options={"BLOCK_M": 16},
+            torch.compile(flex_attention), q, k, v, kernel_options={"BLOCK_M": 16}
         )
         FileCheck().check("BLOCK_M : tl.constexpr = 16").run(code[0])
 
+    @supported_platform
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_comparison_vs_sdpa(self):
+        def causal(score, b, h, q_idx, kv_idx):
+            return torch.where(q_idx >= kv_idx, score, -float("inf"))
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        no_sparse_flex = functools.partial(flex_attention, score_mod=causal)
+        score_mod_sparse_flex = functools.partial(
+            flex_attention,
+            score_mod=causal,
+            block_mask=create_block_mask(causal_mask, 1, 1, 2048, 2048),
+        )
+        mask_mod_sparse_flex = functools.partial(
+            flex_attention, block_mask=create_block_mask(causal_mask, 1, 1, 2048, 2048)
+        )
+        for attention_call in [
+            no_sparse_flex,
+            score_mod_sparse_flex,
+            mask_mod_sparse_flex,
+        ]:
+            inputs = [
+                torch.randn(
+                    2,
+                    2,
+                    2048,
+                    64,
+                    device="cuda",
+                    dtype=torch.float16,
+                    requires_grad=True,
+                )
+                for _ in range(3)
+            ]
+            gradOut = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float16)
+            out_ref = torch.nn.functional.scaled_dot_product_attention(
+                *inputs, is_causal=True
+            )
+            out_ref.backward(gradOut)
+
+            inputs_flex = [i.detach().clone().requires_grad_(True) for i in inputs]
+            out_flex = torch.compile(attention_call)(*inputs_flex)
+            out_flex.backward(gradOut)
+            inputs_golden = [
+                i.detach().clone().to(dtype=torch.float64).requires_grad_(True)
+                for i in inputs
+            ]
+            out_golden = torch.nn.functional.scaled_dot_product_attention(
+                *inputs_golden, is_causal=True
+            )
+            out_golden.backward(gradOut.to(dtype=torch.float64))
+
+            for ref, flex, golden in [
+                (out_ref, out_flex, out_golden),
+                (inputs[0].grad, inputs_flex[0].grad, inputs_golden[0].grad),
+                (inputs[1].grad, inputs_flex[1].grad, inputs_golden[1].grad),
+                (inputs[2].grad, inputs_flex[2].grad, inputs_golden[2].grad),
+            ]:
+                ref_error = rmse(ref, golden)
+                flex_error = rmse(flex, golden)
+                # Note: This has been carefully tested that FlexAttention is within
+                # 20% of the average error of SDPA! Do not bump this tolerance
+                # unless you are absolutely sure you are not worsening the accuracy
+                # of FlexAttention!
+                self.assertTrue(
+                    ref_error * 1.2 > flex_error,
+                    f"Ref error: {ref_error}, Flex Error: {flex_error}",
+                )
 
     @supported_platform
     def test_block_mask_non_divisible(self, device):
         seq = torch.arange(1023, device=device) // 128
@@ -3035,12 +2974,249 @@
         torch.compile(create_block_mask)(mod, None, None, 1023, 1023, device=device)
         self.run_test_with_call(
             lambda q, k, v: flex_attention(q, k, v, block_mask=block_mask),
-            torch.float16,
-            device,
             Q_S=1023,
             KV_S=1023,
+            device=device,
         )
 
+    @supported_platform
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_head_bias_req_grad(self):
+        B, H, S, D = 1, 4, 256, 64
+        bias = torch.randn(H, device="cuda", dtype=torch.float16, requires_grad=True)
+
+        bias_flex = bias.detach().clone().requires_grad_(True)
+
+        def head_bias(score, b, h, q_idx, kv_idx):
+            return score + bias_flex[h]
+
+        bias_sdpa_ref = bias.detach().clone().requires_grad_(True)
+        implicit_bias_sdpa_ref = bias_sdpa_ref
+        implicit_bias_sdpa_ref = implicit_bias_sdpa_ref.view(H, 1, 1).expand(H, S, S)
+        bias_sdpa_gold = (
+            bias.detach().clone().to(dtype=torch.float64).requires_grad_(True)
+        )
+        implicit_bias_sdpa_gold = bias_sdpa_gold
+        implicit_bias_sdpa_gold = implicit_bias_sdpa_gold.view(H, 1, 1).expand(H, S, S)
+
+        self._test_learnable_bias_inner(
+            B,
+            H,
+            S,
+            D,
+            head_bias,
+            bias_flex,
+            implicit_bias_sdpa_ref,
+            bias_sdpa_ref,
+            implicit_bias_sdpa_gold,
+            bias_sdpa_gold,
+        )
+
+    @supported_platform
+    @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
+    def test_comparison_vs_sdpa_with_learnable_bias(self):
+        # 1-dimensional bias:
+        B, H, S, D = 1, 1, 256, 64
+        bias = torch.randn(
+            2 * S, device="cuda", dtype=torch.float16, requires_grad=True
+        )
+
+        bias_flex = bias.detach().clone().requires_grad_(True)
+
+        def rel_pos_1d(score, b, h, q_idx, kv_idx):
+            return score + bias_flex[q_idx + kv_idx]
+
+        bias_indices = torch.arange(S)[:, None] + torch.arange(S)
+        bias_sdpa_ref = bias.detach().clone().requires_grad_(True)
+        implicit_bias_sdpa_ref = bias_sdpa_ref[bias_indices]
+        bias_sdpa_gold = (
+            bias.detach().clone().to(dtype=torch.float64).requires_grad_(True)
+        )
+        implicit_bias_sdpa_gold = bias_sdpa_gold[bias_indices]
+
+        self._test_learnable_bias_inner(
+            B,
+            H,
+            S,
+            D,
+            rel_pos_1d,
+            bias_flex,
+            implicit_bias_sdpa_ref,
+            bias_sdpa_ref,
+            implicit_bias_sdpa_gold,
+            bias_sdpa_gold,
+        )
+
+        # 2-dimensional bias:
+        B, H, S, D = 1, 1, 256, 64
+        bias = torch.randn(S, S, device="cuda", dtype=torch.float16, requires_grad=True)
+
+        bias_flex = bias.detach().clone().requires_grad_(True)
+
+        def rel_pos_2d(score, b, h, q_idx, kv_idx):
+            return score + bias_flex[q_idx, kv_idx]
+
+        bias_sdpa_ref = bias.detach().clone().requires_grad_(True)
+        implicit_bias_sdpa_ref = bias_sdpa_ref
+        bias_sdpa_gold = (
+            bias.detach().clone().to(dtype=torch.float64).requires_grad_(True)
+        )
+        implicit_bias_sdpa_gold = bias_sdpa_gold
+
+        self._test_learnable_bias_inner(
+            B,
+            H,
+            S,
+            D,
+            rel_pos_2d,
+            bias_flex,
+            implicit_bias_sdpa_ref,
+            bias_sdpa_ref,
+            implicit_bias_sdpa_gold,
+            bias_sdpa_gold,
+        )
+
+        # 2-dimensional bias + index multiple
+        B, H, S, D = 1, 1, 256, 64
+        bias = torch.randn(S, S, device="cuda", dtype=torch.float16, requires_grad=True)
+
+        bias_flex = bias.detach().clone().requires_grad_(True)
+
+        def rel_pos_2d(score, b, h, q_idx, kv_idx):
+            return score + bias_flex[q_idx][kv_idx]
+
+        bias_sdpa_ref = bias.detach().clone().requires_grad_(True)
+        implicit_bias_sdpa_ref = bias_sdpa_ref
bias_sdpa_gold = ( + bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) + ) + implicit_bias_sdpa_gold = bias_sdpa_gold + + self._test_learnable_bias_inner( + B, + H, + S, + D, + rel_pos_2d, + bias_flex, + implicit_bias_sdpa_ref, + bias_sdpa_ref, + implicit_bias_sdpa_gold, + bias_sdpa_gold, + ) + + # 2-dimensional bias + transposed: + B, H, S, D = 1, 1, 256, 64 + bias = torch.randn(S, S, device="cuda", dtype=torch.float16, requires_grad=True) + + bias_flex = bias.detach().clone().requires_grad_(True) + + def rel_pos_2d_transposed(score, b, h, q_idx, kv_idx): + return score + bias_flex[kv_idx, q_idx] + + bias_sdpa_ref = bias.detach().clone().requires_grad_(True) + implicit_bias_sdpa_ref = bias_sdpa_ref.transpose(-1, -2) + bias_sdpa_gold = ( + bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) + ) + implicit_bias_sdpa_gold = bias_sdpa_gold.transpose(-1, -2) + + self._test_learnable_bias_inner( + B, + H, + S, + D, + rel_pos_2d_transposed, + bias_flex, + implicit_bias_sdpa_ref, + bias_sdpa_ref, + implicit_bias_sdpa_gold, + bias_sdpa_gold, + ) + + # 3-dimensional bias + transposed + B, H, S, D = 4, 8, 256, 64 + bias = torch.randn( + H, S, S, device="cuda", dtype=torch.float16, requires_grad=True + ) + + bias_flex = bias.detach().clone().requires_grad_(True) + + def rel_pos_3d_transposed(score, b, h, q_idx, kv_idx): + return score + bias_flex[h, kv_idx, q_idx] + + bias_sdpa_ref = bias.detach().clone().requires_grad_(True) + implicit_bias_sdpa_ref = bias_sdpa_ref.transpose(-1, -2) + bias_sdpa_gold = ( + bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) + ) + implicit_bias_sdpa_gold = bias_sdpa_gold.transpose(-1, -2) + + self._test_learnable_bias_inner( + B, + H, + S, + D, + rel_pos_3d_transposed, + bias_flex, + implicit_bias_sdpa_ref, + bias_sdpa_ref, + implicit_bias_sdpa_gold, + bias_sdpa_gold, + ) + + def _test_learnable_bias_inner( + self, + B, + H, + S, + D, + score_mod, + bias_flex, + implicit_bias_sdpa_ref, + bias_sdpa_ref, + implicit_bias_sdpa_gold, + bias_sdpa_gold, + ): + make_tensor = functools.partial( + torch.ones, + (B, H, S, D), + device="cuda", + dtype=torch.float16, + requires_grad=True, + ) + q_ref, k_ref, v_ref = make_tensor(), make_tensor(), make_tensor() + q_gold, k_gold, v_gold = query_key_value_clones( + q_ref, k_ref, v_ref, torch.float64 + ) + q_flex, k_flex, v_flex = query_key_value_clones(q_ref, k_ref, v_ref) + + out_ref = torch.nn.functional.scaled_dot_product_attention( + q_ref, k_ref, v_ref, attn_mask=implicit_bias_sdpa_ref + ) + out_ref.sum().backward() + out_gold = torch.nn.functional.scaled_dot_product_attention( + q_gold, k_gold, v_gold, attn_mask=implicit_bias_sdpa_gold + ) + out_gold.sum().backward() + out_flex = flex_attention(q_flex, k_flex, v_flex, score_mod=score_mod) + out_flex.sum().backward() + + name = score_mod.__name__ + for ref, flex, gold in [ + (out_ref, out_flex, out_gold), + (q_ref.grad, q_flex.grad, q_gold.grad), + (k_ref.grad, k_flex.grad, k_gold.grad), + (v_ref.grad, v_flex.grad, v_gold.grad), + (bias_sdpa_ref.grad, bias_flex.grad, bias_sdpa_gold.grad), + ]: + ref_error = rmse(ref, gold) + flex_error = rmse(flex, gold) + self.assertTrue( + ref_error * 1.2 >= flex_error, + f"{name} -> Ref error: {ref_error}, Flex eager Error: {flex_error}", + ) + @supported_platform def test_causal_block_non_divisible(self, device): def mask_mod(b, h, q, kv): @@ -3049,14 +3225,14 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): block_mask = create_block_mask(mask_mod, B, 1, S - 1, S - 1, device=device) 
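+        # (S - 1 above keeps Q/KV lengths from dividing evenly into the mask's
+        # block size, so this exercises the partial-block path.)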
attention = functools.partial(flex_attention, block_mask=block_mask) - self.run_test_with_call(attention, torch.float16, device, Q_S=S - 1, KV_S=S - 1) + self.run_test_with_call(attention, Q_S=S - 1, KV_S=S - 1, device=device) @supported_platform - @skip_on_cpu - def test_modular_indexing(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") + def test_modular_indexing(self): B, H, N, D = 100, 12, 128, 64 dtype = torch.bfloat16 - device = torch.device(device) + device = torch.device("cuda") class Attention(torch.nn.Module): def __init__(self): @@ -3089,13 +3265,13 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): m(q, k, v) @supported_platform - @skip_on_cpu - def test_force_write_lse(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") + def test_force_write_lse(self): dtype = torch.float32 make_tensor = functools.partial( torch.randn, (2, 2, 128, 16), - device=device, + device="cuda", dtype=dtype, requires_grad=False, ) @@ -3106,16 +3282,16 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): out_compiled, lse_compiled = flex_compile(query, key, value, return_lse=True) out_paged, lse_paged = self.run_paged_attention( - score_mod=_identity, q=query, k=key, v=value, dtype=dtype, device=device + score_mod=_identity, q=query, k=key, v=value, dtype=dtype ) torch.testing.assert_close(lse_eager, lse_compiled, atol=3e-3, rtol=0) torch.testing.assert_close(lse_eager, lse_paged, atol=3e-3, rtol=0) @supported_platform - @skip_on_cpu + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") @common_utils.parametrize("backend", ["flex_attention", "flex_decode", "eager"]) - def test_lse_masked_output(self, device, backend): + def test_lse_masked_output(self, backend): if backend == "flex_decode": kernel_options = {"FORCE_USE_FLEX_ATTENTION": False} flex_call = torch.compile(flex_attention, fullgraph=True) @@ -3133,7 +3309,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): make_tensor = functools.partial( torch.randn, (2, 2, N_CTX, 64), - device=device, + device="cuda", dtype=torch.float32, requires_grad=True, ) @@ -3196,11 +3372,11 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): torch.testing.assert_close(flex_v_grad, v.grad, atol=3e-3, rtol=2e-3) @supported_platform - @skip_on_cpu - def test_mixed_device_error_message(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip test on CPU only as mixed devices needed") + def test_mixed_device_error_message(self): # Create tensors on different devices cpu_tensor = torch.randn(2, 2, 128, 16, device="cpu") - cuda_tensor = torch.randn(2, 2, 128, 16, device=device) + cuda_tensor = torch.randn(2, 2, 128, 16, device="cuda") # Use different devices for query, key, and value query, key, value = cpu_tensor, cuda_tensor, cpu_tensor @@ -3215,8 +3391,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): flex_attention(query, key, value) @supported_platform - @skip_on_cpu - def test_captured_wrong_device_error_message(self, device): + @unittest.skipIf( + SKIP_UT_ON_CPU, "Skip test on CPU only as wrong cuda device needed" + ) + def test_captured_wrong_device_error_message(self): means = torch.randn(64, 3).cuda() length_scales = torch.logspace(0.001, 0.1, 8) @@ -3230,13 +3408,15 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): expected_error_message = "Buffers cannot be created" - q, k, v = (torch.randn(1, 8, 64, 64, device=device) for _ in range(3)) + q, k, v = (torch.randn(1, 8, 64, 64, device="cuda") for _ in range(3)) with 
self.assertRaisesRegex(RuntimeError, expected_error_message): torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed) @supported_platform - @skip_on_cpu - def test_cant_lower_error_message(self, device): + @unittest.skipIf( + SKIP_UT_ON_CPU, "Skip test on CPU only as wrong cuda device needed" + ) + def test_cant_lower_error_message(self): # We can't lower a 256-element reduction inside a pointwise reduction means = torch.randn(64, 256).cuda() length_scales = torch.logspace(0.001, 0.1, 8).cuda() @@ -3251,16 +3431,16 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): expected_error_message = "Buffers cannot be created" - q, k, v = (torch.randn(1, 8, 64, 64, device=device) for _ in range(3)) + q, k, v = (torch.randn(1, 8, 64, 64, device="cuda") for _ in range(3)) with self.assertRaisesRegex(RuntimeError, expected_error_message): torch.compile(flex_attention)(q, k, v, score_mod=euclidean_dist_pos_embed) @supported_platform - @skip_on_cpu - def test_reduction_unrolled(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") + def test_reduction_unrolled(self): # We can't lower a 256-element reduction inside a pointwise reduction - means = torch.randn(S, 3).to(device) - length_scales = torch.logspace(0.001, 0.1, H).to(device) + means = torch.randn(S, 3).cuda() + length_scales = torch.logspace(0.001, 0.1, H).cuda() def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx): q_pos = means[q_idx] @@ -3270,13 +3450,13 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): inv_dist = torch.exp(-dist / scale) return inv_dist * score - self.run_test(euclidean_dist_pos_embed, torch.bfloat16, device) + self.run_test(euclidean_dist_pos_embed, torch.bfloat16) @supported_platform - @skip_on_cpu - def test_invalid_block_size(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip test on CPU only as error_message on cuda") + def test_invalid_block_size(self): # Create tensors on different devices - q, k, v = (torch.randn(1, 8, 128, 64, device=device) for _ in range(3)) + q, k, v = (torch.randn(1, 8, 128, 64, device="cuda") for _ in range(3)) expected_error_message = ( "ValueError: Q and KV block size must be divisible by BLOCK_M and BLOCK_N." 
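A minimal sketch of how the divisibility check above is expected to fire, assuming a CUDA build with Triton; the BLOCK_SIZE=17 value is hypothetical, chosen only because it cannot divide the kernel tile sizes BLOCK_M/BLOCK_N evenly:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Building the mask eagerly succeeds; the divisibility constraint is only
# validated when the Triton kernel is compiled, which should raise
# "ValueError: Q and KV block size must be divisible by BLOCK_M and BLOCK_N."
q, k, v = (torch.randn(1, 8, 128, 64, device="cuda") for _ in range(3))
block_mask = create_block_mask(causal, 1, 1, 128, 128, BLOCK_SIZE=17, device="cuda")
torch.compile(flex_attention)(q, k, v, block_mask=block_mask)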
@@ -3287,12 +3467,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): torch.compile(flex_attention)(q, k, v, block_mask=block_mask) @supported_platform - @skip_on_cpu - def test_small_q_kv_len(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip test on CPU only as kernel_options on cuda") + def test_small_q_kv_len(self): make_tensor = functools.partial( torch.ones, (1, 1, 1, 16), - device=device, + device="cuda", dtype=torch.float32, requires_grad=True, ) @@ -3316,8 +3496,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): torch.testing.assert_close(grads_eager, grads_compile) @supported_platform - @skip_on_cpu - def test_dynamic_shapes_bug_dynamic_batch(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip test on CPU only as cuda device expected") + def test_dynamic_shapes_bug_dynamic_batch(self): def _flex_attention_mask(b, h, q_idx, kv_idx, input_lengths): padding_condition = (q_idx < input_lengths[b]) & (kv_idx < input_lengths[b]) return padding_condition @@ -3363,8 +3543,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): model = Model(128).cuda() B, F, T = 16, 256, 12 for _ in range(5): - x = torch.randn(B, T, F, device=device) - l = torch.randint(0, T, (B,), device=device) + x = torch.randn(B, T, F, device="cuda") + l = torch.randint(0, T, (B,), device="cuda") model(x, l) assert ( @@ -3372,12 +3552,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): ), f"Expected 1 graph, but got {counter.frame_count} graphs" @supported_platform - @skip_on_cpu - def test_dynamic_shapes_with_custom_kernel_options(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") + def test_dynamic_shapes_with_custom_kernel_options(self): make_tensor = functools.partial( torch.ones, (8, 8, 1024, 64), - device=device, + device="cuda", dtype=torch.bfloat16, ) query, key, value = make_tensor(), make_tensor(), make_tensor() @@ -3412,12 +3592,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): torch.testing.assert_close(out_eager, out_compiled, atol=3e-3, rtol=2e-3) @supported_platform - @skip_on_cpu - def test_zero_length_sequence_error(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") + def test_zero_length_sequence_error(self): make_tensor = functools.partial( torch.ones, (8, 8, 0, 64), # Zero in sequence dimension - device=device, + device="cuda", dtype=torch.bfloat16, ) query, key, value = make_tensor(), make_tensor(), make_tensor() @@ -3449,9 +3629,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): attention = functools.partial(flex_attention, block_mask=block_mask) - self.run_test_with_call( - attention, Q_S=Q_S, KV_S=KV_S, dtype=torch.bfloat16, device=device - ) + self.run_test_with_call(attention, Q_S=Q_S, KV_S=KV_S, device=device) @supported_platform def test_non_divisible_with_captured_buffer(self, device): @@ -3467,21 +3645,13 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): flex_attention, score_mod=apply_multiplicative_bias ) - self.run_test_with_call( - attention, Q_S=Q_S, KV_S=KV_S, dtype=torch.bfloat16, device=device - ) + self.run_test_with_call(attention, Q_S=Q_S, KV_S=KV_S, device=device) @supported_platform def test_num_warps_8_error(self, device): attention = functools.partial(flex_attention, score_mod=_identity) self.run_test_with_call( - attention, - dtype=torch.float16, - device=device, - Q_S=128, - KV_S=128, - Q_D=128, - V_D=128, + attention, Q_S=128, KV_S=128, Q_D=128, V_D=128, device=device ) @supported_platform @@ -3506,8 +3676,8 @@ def 
forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): torch.compile(flex_attention)(query, key, value, block_mask=block_mask) @supported_platform - @skip_on_cpu - def test_free_symbol_dynamic(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") + def test_free_symbol_dynamic(self): def batch_flip_causal(b, h, q_idx, kv_idx): return (q_idx >= kv_idx) & (b % 2 == 0) @@ -3547,7 +3717,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): H=None, Q_LEN=sequence_len, KV_LEN=sequence_len, - device=device, + device="cuda", ) # Run forward pass @@ -3557,8 +3727,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2) @supported_platform - @skip_on_cpu - def test_symbol_closure_in_score_mod(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as cuda backends needed") + def test_symbol_closure_in_score_mod(self): class SimpleAttention(torch.nn.Module): def __init__(self, dim=512, n_head=8): super().__init__() @@ -3600,13 +3770,13 @@ def forward(self, child : torch.Tensor, child_1 : torch.Tensor, child_2 : torch. ) @supported_platform - @skip_on_cpu - def test_fw_bw_graph_correctness(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") + def test_fw_bw_graph_correctness(self): cnt = CompileCounterWithBackend("aot_eager") make_tensor = functools.partial( torch.randn, (2, 2, 128, 4), - device=device, + device="cuda", dtype=torch.float64, requires_grad=True, ) @@ -3712,7 +3882,7 @@ class GraphModule(torch.nn.Module): ) @supported_platform - @skip_on_cuda + @unittest.skipIf(TEST_ON_CUDA, "Testing CPU error message") def test_cpu_error_message_return_lse(self, device): make_tensor = functools.partial( torch.randn, @@ -3751,10 +3921,10 @@ class GraphModule(torch.nn.Module): self.assertEqual(attn_output.device, torch.device("cuda:1")) @supported_platform - @skip_on_cpu - def test_validate_small_embedding_size_error_message(self, device): + @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported") + def test_validate_small_embedding_size_error_message(self): # eager support for small embedding size - q, k, v = [torch.randn(2, 2, 128, 8, device=device) for _ in range(3)] + q, k, v = [torch.randn(2, 2, 128, 8, device="cuda") for _ in range(3)] flex_attention(q, k, v) # compiled cpu support for small embedding size @@ -3762,7 +3932,7 @@ class GraphModule(torch.nn.Module): flex_attention(q, k, v) # compiled gpu kernel does not support small embedding size - q, k, v = [torch.randn(2, 2, 128, 8, device=device) for _ in range(3)] + q, k, v = [torch.randn(2, 2, 128, 8, device="cuda") for _ in range(3)] compiled_fa = torch.compile(flex_attention) with self.assertRaisesRegex( @@ -3773,16 +3943,16 @@ class GraphModule(torch.nn.Module): compiled_fa(q, k, v) # compiled gpu kernel supports large embedding size - q, k, v = [torch.randn(2, 2, 128, 16, device=device) for _ in range(3)] + q, k, v = [torch.randn(2, 2, 128, 16, device="cuda") for _ in range(3)] compiled_fa = torch.compile(flex_attention) @unittest.skipIf( not has_triton() or not HAS_WARP_SPEC, reason="FBCODE Triton is required for this test", ) - def test_triton_template_warp_specialization(self, device): + def test_triton_template_warp_specialization(self): def make_tensor(): - return torch.rand(4, 16, 4096, 64, device=device, dtype=torch.bfloat16) + return torch.rand(4, 16, 4096, 64, device="cuda", dtype=torch.bfloat16) q, k, v = make_tensor(), make_tensor(), make_tensor() flex_compiled = 
torch.compile(flex_attention, fullgraph=True)
@@ -3822,10 +3992,14 @@ class GraphModule(torch.nn.Module):
 class TestBlockMask(InductorTestCase):
     def setUp(self):
         super().setUp()
+        if test_device[0] == "cpu":
+            self.skipTest(
+                "skip UT for CPUs as 'BlockMask' is common and covered on CUDA"
+            )
 
     @supported_platform
-    def test_block_mask_attributes(self, device):
-        offset = torch.zeros(8, device=device)
+    def test_block_mask_attributes(self):
+        offset = torch.zeros(8, device="cuda")
 
         def causal_mask(b, h, q, kv):
             return (q + (offset[b] * 128)) >= kv
@@ -3840,7 +4014,7 @@ class TestBlockMask(InductorTestCase):
         self.assertEqual(block_mask[1, 0].sparsity(), 46.875)
         self.assertEqual(block_mask.sparsity(), block_mask[1].sparsity())
 
-        offset = torch.arange(8, device=device)
+        offset = torch.arange(8, device="cuda")
         block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048)
         self.assertEqual(block_mask.sparsity(), 29.1015625)
         self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity())
@@ -3848,7 +4022,7 @@ class TestBlockMask(InductorTestCase):
 
     @supported_platform
     @common_utils.parametrize("BLOCK_SIZE", [32, 64, 128, 256, (32, 64), (64, 32)])
-    def test_block_size_changes(self, device, BLOCK_SIZE: Union[int, tuple[int, int]]):
+    def test_block_size_changes(self, BLOCK_SIZE: Union[int, tuple[int, int]]):
         B, H, Q_LEN, KV_LEN = 4, 2, 2048, 2048
 
         if isinstance(BLOCK_SIZE, int):
@@ -3865,8 +4039,8 @@ class TestBlockMask(InductorTestCase):
         self.assertEqual(block_mask.shape, (B, H, Q_LEN, KV_LEN))
 
     @supported_platform
-    def test_getitem(self, device):
-        offset = torch.zeros(8, device=device)
+    def test_getitem(self):
+        offset = torch.zeros(8, device="cuda")
 
         def causal_mask(b, h, q, kv):
             return (q + (offset[b] * 128)) >= kv
@@ -3922,8 +4096,8 @@ class TestBlockMask(InductorTestCase):
         )
 
     @supported_platform
-    def test_block_mask_device_change(self, device):
-        offset = torch.zeros(8, device=device)
+    def test_block_mask_device_change(self):
+        offset = torch.zeros(8, device="cuda")
 
         def causal_mask(b, h, q, kv):
             return (q + (offset[b] * 128)) >= kv
@@ -3947,8 +4121,8 @@ class TestBlockMask(InductorTestCase):
         assert block_mask.q_num_blocks.is_cuda
 
     @supported_platform
-    def test_compiling_create_block_mask(self, device):
-        seq = torch.arange(512, device=device) // 127
+    def test_compiling_create_block_mask(self):
+        seq = torch.arange(512, device="cuda") // 127
 
         def mask_mod(b, h, q, kv):
             return (q >= kv) & (seq[q] == seq[kv])
@@ -3961,7 +4135,7 @@ class TestBlockMask(InductorTestCase):
         self.assertEqual(block_mask.kv_indices.shape, torch.Size((1, 1, 4, 4)))
 
     @supported_platform
-    def test_compiling_create_block_mask_no_recompile(self, device):
+    def test_compiling_create_block_mask_no_recompile(self):
         def mask_mod(b, h, q, kv):
             return q >= kv
@@ -3987,7 +4161,7 @@ class TestBlockMask(InductorTestCase):
         self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2)
 
     @supported_platform
-    def test_block_mask_viz(self, device):
+    def test_block_mask_viz(self):
         def causal_mask(b, h, q, kv):
             return q >= kv
@@ -4027,7 +4201,7 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
 )""",
         )
 
-        offset = torch.arange(8, device=device)
+        offset = torch.arange(8, device="cuda")
 
         def causal_offset_mask(b, h, q, kv):
             return (q + offset[b] * 128) >= kv
@@ -4063,7 +4237,8 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
 
     @supported_platform
     @common_utils.parametrize("full_indices", [False, True])
-    def test_from_kv_blocks(self, device, full_indices: bool):
+    def test_from_kv_blocks(self, full_indices: bool):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         (
             kv_num_blocks,
             kv_indices,
@@ -4116,7 +4291,8 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
         self.assertIsNone(block_mask.full_q_indices)
 
     @supported_platform
-    def test_block_size(self, device):
+    def test_block_size(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         kv_num_blocks, kv_indices, _, _ = self.generate_test_inputs(False, device)
         block_mask = BlockMask.from_kv_blocks(kv_num_blocks, kv_indices)
         self.assertEqual(
@@ -4131,11 +4307,11 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
         self.assertEqual(block_mask_custom.BLOCK_SIZE, custom_block_size)
 
     @supported_platform
-    def test_upcast_appropriately(self, device):
-        q = torch.randn((1, 1, 128, 16), dtype=torch.float16, device=device)
-        k = torch.randn((1, 1, 128, 16), dtype=torch.float16, device=device)
-        v = torch.randn((1, 1, 128, 16), dtype=torch.float16, device=device)
-        mass = torch.ones((1), dtype=torch.float16, device=device)
+    def test_upcast_appropriately(self):
+        q = torch.randn((1, 1, 128, 16), dtype=torch.float16, device="cuda")
+        k = torch.randn((1, 1, 128, 16), dtype=torch.float16, device="cuda")
+        v = torch.randn((1, 1, 128, 16), dtype=torch.float16, device="cuda")
+        mass = torch.ones((1), dtype=torch.float16, device="cuda")
 
         def score_mod(score, b, h, q_idx, kv_idx):
             return score + torch.log(mass[0])
@@ -4143,7 +4319,8 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
         torch.compile(flex_attention)(q, k, v, score_mod=score_mod)
 
     @supported_platform
-    def test_init_mismatched_full_kv(self, device):
+    def test_init_mismatched_full_kv(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         kv_num_blocks, kv_indices, full_kv_num_blocks, _ = self.generate_test_inputs(
             True, device
         )
@@ -4164,7 +4341,8 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
         )
 
     @supported_platform
-    def test_init_mismatched_full_q(self, device):
+    def test_init_mismatched_full_q(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         kv_num_blocks, kv_indices, _, _ = self.generate_test_inputs(False, device)
 
         with self.assertRaises(AssertionError):
@@ -4184,7 +4362,7 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
 
     @supported_platform
     @common_utils.parametrize("compile", [False, True])
-    def test_no_q_info(self, device, compile: bool):
+    def test_no_q_info(self, compile: bool):
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx
@@ -4206,7 +4384,7 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
             2,
             2048,
             64,
-            device=device,
+            device="cuda",
             dtype=torch.float16,
             requires_grad=True,
         )
@@ -4221,7 +4399,7 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
         torch.testing.assert_close(causal_mask_out, sdpa_mask_out, atol=5e-3, rtol=0.0)
 
     @supported_platform
-    def test_doc_mask_clamped_repro(self, device):
+    def test_doc_mask_clamped_repro(self):
         def _offsets_to_doc_ids_tensor(offsets):
             device = offsets.device
             counts = offsets[1:] - offsets[:-1]
@@ -4303,7 +4481,7 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
         torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
 
     @supported_platform
-    def test_eager_tracing_correctness(self, device):
+    def test_eager_tracing_correctness(self):
         qk_dims = 64
         v_dims = 128
         q_heads = 4
@@ -4311,7 +4489,7 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
         seq_len = 256
         batch_size = 1
 
-        make_tensor = functools.partial(torch.randn, device=device, dtype=torch.float16)
+        make_tensor = functools.partial(torch.randn, device="cuda", dtype=torch.float16)
         q = make_tensor(*(batch_size, q_heads, seq_len, qk_dims))
         k = make_tensor(*(batch_size, kv_heads, seq_len, qk_dims))
         v = make_tensor(*(batch_size, kv_heads, seq_len, v_dims))
@@ -4333,7 +4511,7 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
         )
 
     @supported_platform
-    def test_create_is_cuda_graphable(self, device):
+    def test_create_is_cuda_graphable(self):
         def mask_mod(b, h, q, kv):
             return q >= kv
@@ -4346,7 +4524,7 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
 
     @common_utils.parametrize("compile", [False, True])
     @supported_platform
-    def test_block_mask_vs_sequence_lengths(self, device, compile):
+    def test_block_mask_vs_sequence_lengths(self, compile):
         if compile:
             flex_attention_call = torch.compile(flex_attention)
         else:
@@ -4358,7 +4536,7 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
         def create_inputs(S):
             q, k, v = (
                 torch.randn(
-                    1, 8, S, 64, dtype=torch.float16, requires_grad=True, device=device
+                    1, 8, S, 64, dtype=torch.float16, requires_grad=True, device="cuda"
                 )
                 for _ in range(3)
             )
@@ -4378,10 +4556,11 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
 class TestPagedAttention(InductorTestCase):
     def setUp(self):
         super().setUp()
-        skipCPUIf(
-            LONG_COMPILATION_ON_CPU,
-            "skip UT for CPU due to long compilation time found in CI",
-        )
+        if test_device[0] == "cpu":
+            if LONG_COMPILATION_ON_CPU:
+                self.skipTest(
+                    "skip UT for CPU due to long compilation time found in CI"
+                )
 
     def _check_equal(
         self,
@@ -4400,7 +4579,7 @@ class TestPagedAttention(InductorTestCase):
             msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
             self.assertTrue(False, msg)
 
-    def allocate_page_cache(self, n_pages: int, page_size: int, device: str):
+    def allocate_page_cache(self, n_pages: int, page_size: int, device="cuda"):
         max_batch_size = 3
         paged_cache = PagedAttention(n_pages, page_size, max_batch_size, device=device)
         return paged_cache
@@ -4687,8 +4866,7 @@ class TestPagedAttention(InductorTestCase):
         self.assertEqual(k_cache, expected_cache)
 
     @supported_platform
-    @dtypes(*device_configs["cpu"].dtypes)
-    @dtypesIfCUDA(*device_configs["cuda"].dtypes)
+    @common_utils.parametrize("dtype", test_dtypes_fast)
     @common_utils.parametrize("score_mod", test_score_mods)
     def test_paged_builtin_score_mods(
         self, device, dtype: torch.dtype, score_mod: Callable
@@ -4818,7 +4996,7 @@ def get_params(dtypes: list[torch.dtype]) -> list[Params]:
 supports_learnable_bias = unittest.skipUnless(
     (torch.cuda.is_available() and has_triton())
     and (torch.cuda.get_device_capability() >= (8, 0) or torch.version.hip),
     "Requires Triton + A100 or Triton + ROCm",
 )
@@ -4895,7 +5073,7 @@ class TestLearnableBiases(InductorTestCase):
         self._gold_check(eager, compiled, gold, name)
 
     @common_utils.parametrize(
-        "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
+        "params", get_params(test_dtypes), name_fn=lambda x: f"{x}"
     )
     @common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"])
     def test_relative_1d_bias(self, params, mode: str):
@@ -4928,7 +5106,7 @@ class TestLearnableBiases(InductorTestCase):
     )
 
     @common_utils.parametrize(
-        "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
+        "params", get_params(test_dtypes), name_fn=lambda x: f"{x}"
     )
     def test_absolute_2d_bias(self, params):
         query, key, value = self._init_tensors(params)
@@ -4961,7
+5139,7 @@ class TestLearnableBiases(InductorTestCase): ) @common_utils.parametrize( - "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" ) def test_head_specific_bias(self, params): query, key, value = self._init_tensors(params) @@ -4995,7 +5173,7 @@ class TestLearnableBiases(InductorTestCase): ) @common_utils.parametrize( - "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" ) def test_batch_head_bias(self, params): query, key, value = self._init_tensors(params) @@ -5030,7 +5208,7 @@ class TestLearnableBiases(InductorTestCase): ) @common_utils.parametrize( - "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" ) def test_multiplicative_bias(self, params): query, key, value = self._init_tensors(params) @@ -5062,7 +5240,7 @@ class TestLearnableBiases(InductorTestCase): ) @common_utils.parametrize( - "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" ) def test_local_window_bias(self, params): query, key, value = self._init_tensors(params) @@ -5096,7 +5274,7 @@ class TestLearnableBiases(InductorTestCase): ) @common_utils.parametrize( - "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" ) def test_global_tokens_bias(self, params): query, key, value = self._init_tensors(params) @@ -5128,7 +5306,7 @@ class TestLearnableBiases(InductorTestCase): ) @common_utils.parametrize( - "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" ) def test_weird_bias(self, params): query, key, value = self._init_tensors(params) @@ -5164,7 +5342,7 @@ class TestLearnableBiases(InductorTestCase): ) @common_utils.parametrize( - "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" ) def test_indirect_bias(self, params): query, key, value = self._init_tensors(params) @@ -5203,7 +5381,7 @@ class TestLearnableBiases(InductorTestCase): ) @common_utils.parametrize( - "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" + "params", get_params([torch.float32]), name_fn=lambda x: f"{x}" ) @common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"]) def test_symmetric_bias(self, params, mode: str): @@ -5240,7 +5418,7 @@ class TestLearnableBiases(InductorTestCase): ) @common_utils.parametrize( - "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" ) def test_flipped_indexed_bias(self, params): query, key, value = self._init_tensors(params) @@ -5273,7 +5451,7 @@ class TestLearnableBiases(InductorTestCase): ) @common_utils.parametrize( - "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" ) @common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"]) def test_head_specific_gate(self, params, mode: str): @@ -5306,7 +5484,7 @@ class TestLearnableBiases(InductorTestCase): ) @common_utils.parametrize( - "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" + "params", 
get_params(test_dtypes), name_fn=lambda x: f"{x}" ) def test_distinct_biases(self, params): query, key, value = self._init_tensors(params) @@ -5354,7 +5532,7 @@ class TestLearnableBiases(InductorTestCase): ) @common_utils.parametrize( - "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}" + "params", get_params(test_dtypes), name_fn=lambda x: f"{x}" ) def test_relative_1d_bias_only_grad(self, params): query, key, value = self._init_tensors(params) @@ -5390,9 +5568,9 @@ class TestLearnableBiases(InductorTestCase): ) def test_flex_attention_with_dynamic_max_autotune(self): - query = torch.randn(2, 16, 512, 64, device=self.device) - key = torch.randn(2, 16, 512, 64, device=self.device) - value = torch.randn(2, 16, 512, 64, device=self.device) + query = torch.randn(2, 16, 512, 64, device="cuda") + key = torch.randn(2, 16, 512, 64, device="cuda") + value = torch.randn(2, 16, 512, 64, device="cuda") query.requires_grad = True key.requires_grad = True value.requires_grad = True @@ -5433,265 +5611,18 @@ class TestLearnableBiases(InductorTestCase): return (q_idx - kv_idx).abs() < val sliding_window2 = functools.partial( - sliding_window, val=torch.randn((), device=self.device) + sliding_window, val=torch.randn((), device="cuda") ) opt_fn = torch.compile(create_block_mask, fullgraph=True) create_block_mask(sliding_window2, None, None, 1024, 1024) # checks that the compile is working opt_fn(sliding_window2, None, None, 1024, 1024) - @supported_platform - def test_head_bias_req_grad(self): - device = self.device - B, H, S, D = 1, 4, 256, 64 - bias = torch.randn( - H, device=self.device, dtype=torch.float16, requires_grad=True - ) - - bias_flex = bias.detach().clone().requires_grad_(True) - - def head_bias(score, b, h, q_idx, kv_idx): - return score + bias_flex[h] - - bias_sdpa_ref = bias.detach().clone().requires_grad_(True) - implicit_bias_sdpa_ref = bias_sdpa_ref - implicit_bias_sdpa_ref = implicit_bias_sdpa_ref.view(H, 1, 1).expand(H, S, S) - bias_sdpa_gold = ( - bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) - ) - implicit_bias_sdpa_gold = bias_sdpa_gold - implicit_bias_sdpa_gold = implicit_bias_sdpa_gold.view(H, 1, 1).expand(H, S, S) - - self._test_learnable_bias_inner( - B, - H, - S, - D, - head_bias, - bias_flex, - implicit_bias_sdpa_ref, - bias_sdpa_ref, - implicit_bias_sdpa_gold, - bias_sdpa_gold, - device, - ) - - @supported_platform - def test_comparison_vs_sdpa_with_learnable_bias(self): - device = self.device - # 1-dimensional bias: - B, H, S, D = 1, 1, 256, 64 - bias = torch.randn( - 2 * S, device=device, dtype=torch.float16, requires_grad=True - ) - - bias_flex = bias.detach().clone().requires_grad_(True) - - def rel_pos_1d(score, b, h, q_idx, kv_idx): - return score + bias_flex[q_idx + kv_idx] - - bias_indices = torch.arange(S)[:, None] + torch.arange(S) - bias_sdpa_ref = bias.detach().clone().requires_grad_(True) - implicit_bias_sdpa_ref = bias_sdpa_ref[bias_indices] - bias_sdpa_gold = ( - bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) - ) - implicit_bias_sdpa_gold = bias_sdpa_gold[bias_indices] - - self._test_learnable_bias_inner( - B, - H, - S, - D, - rel_pos_1d, - bias_flex, - implicit_bias_sdpa_ref, - bias_sdpa_ref, - implicit_bias_sdpa_gold, - bias_sdpa_gold, - device, - ) - - # 2-dimensional bias: - B, H, S, D = 1, 1, 256, 64 - bias = torch.randn(S, S, device=device, dtype=torch.float16, requires_grad=True) - - bias_flex = bias.detach().clone().requires_grad_(True) - - def rel_pos_2d(score, b, h, q_idx, kv_idx): - 
return score + bias_flex[q_idx, kv_idx] - - bias_sdpa_ref = bias.detach().clone().requires_grad_(True) - implicit_bias_sdpa_ref = bias_sdpa_ref - bias_sdpa_gold = ( - bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) - ) - implicit_bias_sdpa_gold = bias_sdpa_gold - - self._test_learnable_bias_inner( - B, - H, - S, - D, - rel_pos_2d, - bias_flex, - implicit_bias_sdpa_ref, - bias_sdpa_ref, - implicit_bias_sdpa_gold, - bias_sdpa_gold, - device, - ) - - # 2-dimensional bias + index multiple - B, H, S, D = 1, 1, 256, 64 - bias = torch.randn(S, S, device=device, dtype=torch.float16, requires_grad=True) - - bias_flex = bias.detach().clone().requires_grad_(True) - - def rel_pos_2d(score, b, h, q_idx, kv_idx): - return score + bias_flex[q_idx][kv_idx] - - bias_sdpa_ref = bias.detach().clone().requires_grad_(True) - implicit_bias_sdpa_ref = bias_sdpa_ref - bias_sdpa_gold = ( - bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) - ) - implicit_bias_sdpa_gold = bias_sdpa_gold - - self._test_learnable_bias_inner( - B, - H, - S, - D, - rel_pos_2d, - bias_flex, - implicit_bias_sdpa_ref, - bias_sdpa_ref, - implicit_bias_sdpa_gold, - bias_sdpa_gold, - device, - ) - - # 2-dimensional bias + transposed: - B, H, S, D = 1, 1, 256, 64 - bias = torch.randn(S, S, device=device, dtype=torch.float16, requires_grad=True) - - bias_flex = bias.detach().clone().requires_grad_(True) - - def rel_pos_2d_transposed(score, b, h, q_idx, kv_idx): - return score + bias_flex[kv_idx, q_idx] - - bias_sdpa_ref = bias.detach().clone().requires_grad_(True) - implicit_bias_sdpa_ref = bias_sdpa_ref.transpose(-1, -2) - bias_sdpa_gold = ( - bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) - ) - implicit_bias_sdpa_gold = bias_sdpa_gold.transpose(-1, -2) - - self._test_learnable_bias_inner( - B, - H, - S, - D, - rel_pos_2d_transposed, - bias_flex, - implicit_bias_sdpa_ref, - bias_sdpa_ref, - implicit_bias_sdpa_gold, - bias_sdpa_gold, - device, - ) - - # 3-dimensional bias + transposed - B, H, S, D = 4, 8, 256, 64 - bias = torch.randn( - H, S, S, device=device, dtype=torch.float16, requires_grad=True - ) - - bias_flex = bias.detach().clone().requires_grad_(True) - - def rel_pos_3d_transposed(score, b, h, q_idx, kv_idx): - return score + bias_flex[h, kv_idx, q_idx] - - bias_sdpa_ref = bias.detach().clone().requires_grad_(True) - implicit_bias_sdpa_ref = bias_sdpa_ref.transpose(-1, -2) - bias_sdpa_gold = ( - bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) - ) - implicit_bias_sdpa_gold = bias_sdpa_gold.transpose(-1, -2) - - self._test_learnable_bias_inner( - B, - H, - S, - D, - rel_pos_3d_transposed, - bias_flex, - implicit_bias_sdpa_ref, - bias_sdpa_ref, - implicit_bias_sdpa_gold, - bias_sdpa_gold, - device, - ) - - def _test_learnable_bias_inner( - self, - B, - H, - S, - D, - score_mod, - bias_flex, - implicit_bias_sdpa_ref, - bias_sdpa_ref, - implicit_bias_sdpa_gold, - bias_sdpa_gold, - device, - ): - make_tensor = functools.partial( - torch.ones, - (B, H, S, D), - device=device, - dtype=torch.float16, - requires_grad=True, - ) - q_ref, k_ref, v_ref = make_tensor(), make_tensor(), make_tensor() - q_gold, k_gold, v_gold = query_key_value_clones( - q_ref, k_ref, v_ref, torch.float64 - ) - q_flex, k_flex, v_flex = query_key_value_clones(q_ref, k_ref, v_ref) - - out_ref = torch.nn.functional.scaled_dot_product_attention( - q_ref, k_ref, v_ref, attn_mask=implicit_bias_sdpa_ref - ) - out_ref.sum().backward() - out_gold = torch.nn.functional.scaled_dot_product_attention( - 
q_gold, k_gold, v_gold, attn_mask=implicit_bias_sdpa_gold - ) - out_gold.sum().backward() - out_flex = flex_attention(q_flex, k_flex, v_flex, score_mod=score_mod) - out_flex.sum().backward() - - name = score_mod.__name__ - for ref, flex, gold in [ - (out_ref, out_flex, out_gold), - (q_ref.grad, q_flex.grad, q_gold.grad), - (k_ref.grad, k_flex.grad, k_gold.grad), - (v_ref.grad, v_flex.grad, v_gold.grad), - (bias_sdpa_ref.grad, bias_flex.grad, bias_sdpa_gold.grad), - ]: - ref_error = rmse(ref, gold) - flex_error = rmse(flex, gold) - self.assertTrue( - ref_error * 1.2 >= flex_error, - f"{name} -> Ref error: {ref_error}, Flex eager Error: {flex_error}", - ) - instantiate_device_type_tests(TestFlexAttention, globals(), only_for=test_device) instantiate_device_type_tests(TestPagedAttention, globals(), only_for=test_device) -instantiate_device_type_tests(TestBlockMask, globals(), only_for=("cuda",)) +common_utils.instantiate_parametrized_tests(TestBlockMask) common_utils.instantiate_parametrized_tests(TestLearnableBiases) if __name__ == "__main__":