[FlexAttention] Improve error msg for embedding < 16 (#147765)

flex_attention uses tl.dot, which [does not support input shapes with an embedding dimension < 16](https://github.com/triton-lang/triton/issues/2266). This PR adds an explicit error message for users who are prototyping with small tensors.
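
For context, a minimal repro sketch of the behavior this PR targets (assuming a CUDA device is available; the shapes mirror the new test):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Embedding dimension 8 (< 16), as in the new test below.
q, k, v = [torch.randn(2, 2, 128, 8, device="cuda") for _ in range(3)]

flex_attention(q, k, v)  # eager path: works for small embedding sizes

# The compiled CUDA kernel lowers to tl.dot, so after this PR the call below
# raises the explicit "NYI: embedding dimension ... must be at least 16" error.
torch.compile(flex_attention)(q, k, v)
```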

Fixes #147701

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147765
Approved by: https://github.com/drisspg
Boyuan Feng 2025-02-26 17:06:35 +00:00 committed by PyTorch MergeBot
parent ac926f81cc
commit ba9ed856e0
3 changed files with 37 additions and 3 deletions

View File

@@ -2133,10 +2133,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         torch.compile(flex_attention)(q, k, v, score_mod, block_mask=block_mask)
 
     @supported_platform
-    @common_utils.parametrize("head_dim", [13, 24, 94, 121])
+    @common_utils.parametrize("head_dim", [17, 24, 94, 121])
     @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_non_pow_2_headdim(self, dtype, head_dim):
-        self.run_test(_rel_bias, torch.float16, B, H, S, head_dim, B, H, S, head_dim)
+        self.run_test(_rel_bias, dtype, B, H, S, head_dim, B, H, S, head_dim)
 
     @supported_platform
     def test_GQA_causal_mask(self):
@@ -3715,6 +3715,31 @@ class GraphModule(torch.nn.Module):
         attn_output = mod(q, k, v, mask)
         self.assertEqual(attn_output.device, torch.device("cuda:1"))
 
+    @supported_platform
+    def test_validate_small_embedding_size_error_message(self):
+        # eager support for small embedding size
+        q, k, v = [torch.randn(2, 2, 128, 8, device="cuda") for _ in range(3)]
+        flex_attention(q, k, v)
+
+        # compiled cpu support for small embedding size
+        q, k, v = [torch.randn(2, 2, 128, 8, device="cpu") for _ in range(3)]
+        flex_attention(q, k, v)
+
+        # compiled gpu kernel does not support small embedding size
+        q, k, v = [torch.randn(2, 2, 128, 8, device="cuda") for _ in range(3)]
+        compiled_fa = torch.compile(flex_attention)
+        with self.assertRaisesRegex(
+            torch._inductor.exc.InductorError,
+            "NYI: embedding dimension of the query, key, and value must be "
+            "at least 16 but got E=8 and Ev=8",
+        ):
+            compiled_fa(q, k, v)
+
+        # compiled gpu kernel supports large embedding size
+        q, k, v = [torch.randn(2, 2, 128, 16, device="cuda") for _ in range(3)]
+        compiled_fa = torch.compile(flex_attention)
+        compiled_fa(q, k, v)
+
 
 class TestBlockMask(InductorTestCase):
     @supported_platform

View File

@@ -1070,7 +1070,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.run_test_with_paged_attention(score_mod_scale, dtype)
 
     @supported_platform
-    @common_utils.parametrize("head_dim", [13, 24, 94, 121])
+    @common_utils.parametrize("head_dim", [17, 24, 94, 121])
     @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_non_pow_2_headdim(self, dtype, head_dim):
         self.run_test(_rel_bias, dtype, B, Hq, S, head_dim, B, Hkv, S, head_dim)

View File

@@ -1255,6 +1255,15 @@ def flex_attention(
         )
 
+    # below is cuda path if device is not cpu
+    # tl.dot does not support embedding size less than 16
+    small_dqk = V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-1], 16))
+    small_dv = V.graph.sizevars.evaluate_expr(sympy.Lt(value.get_size()[-1], 16))
+    if small_dqk or small_dv:
+        raise NotImplementedError(
+            f"NYI: embedding dimension of the query, key, and value must be "
+            f"at least 16 but got E={query.get_size()[-1]} and Ev={value.get_size()[-1]}"
+        )
 
     (
         _,  # q_length
         _,  # kv_length
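
For reference, a hedged user-side sketch of how a caller could route around the new NYI error when prototyping with tiny head dimensions. `flex_attention_or_eager` is a hypothetical helper, not part of this PR; the 16 threshold mirrors the check added above.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

_compiled_flex = torch.compile(flex_attention)

def flex_attention_or_eager(q, k, v, **kwargs):
    # Hypothetical helper: the compiled CUDA kernel lowers to tl.dot, which
    # needs the last (embedding) dim of q/k and v to be at least 16, so fall
    # back to the eager implementation for smaller prototyping shapes.
    if q.is_cuda and (q.shape[-1] < 16 or v.shape[-1] < 16):
        return flex_attention(q, k, v, **kwargs)
    return _compiled_flex(q, k, v, **kwargs)

# Small embedding dim takes the eager path; E >= 16 uses the compiled kernel.
q, k, v = [torch.randn(2, 2, 128, 8, device="cuda") for _ in range(3)]
out = flex_attention_or_eager(q, k, v)
```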