Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[FlexAttention] Improve error msg for embedding < 16 (#147765)
flex_attention uses tl.dot, which [does not support an embedding dimension < 16](https://github.com/triton-lang/triton/issues/2266) in its input shapes. This PR adds an explicit error message for users who are prototyping with small tensors.

Fixes #147701

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147765
Approved by: https://github.com/drisspg
This commit is contained in:
parent
ac926f81cc
commit
ba9ed856e0
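For context, here is a minimal sketch (not part of this commit) of the case the new check covers. It assumes a CUDA device is available and reuses the shapes from the test added below (B=2, H=2, S=128, head_dim=8, i.e. an embedding dimension below 16):

```python
# Minimal sketch of the behavior described above; not part of this commit.
# Assumes PyTorch with flex_attention and a CUDA device are available.
import torch
from torch.nn.attention.flex_attention import flex_attention

# Embedding dimension (last dim) of 8 is below the 16 required by tl.dot.
q, k, v = [torch.randn(2, 2, 128, 8, device="cuda") for _ in range(3)]

# Eager flex_attention still handles the small embedding dimension.
flex_attention(q, k, v)

# The compiled CUDA kernel lowers to tl.dot, so with this change
# compilation fails with the explicit NYI message instead of a Triton error.
compiled_fa = torch.compile(flex_attention)
try:
    compiled_fa(q, k, v)
except Exception as err:  # surfaced as an Inductor error wrapping NotImplementedError
    print(err)
```

The message comes from the check added in the Inductor lowering (last hunk below); it goes through `V.graph.sizevars.evaluate_expr`, presumably so the size comparison also resolves when the shapes are symbolic.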
@@ -2133,10 +2133,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         torch.compile(flex_attention)(q, k, v, score_mod, block_mask=block_mask)
 
     @supported_platform
-    @common_utils.parametrize("head_dim", [13, 24, 94, 121])
+    @common_utils.parametrize("head_dim", [17, 24, 94, 121])
     @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_non_pow_2_headdim(self, dtype, head_dim):
-        self.run_test(_rel_bias, torch.float16, B, H, S, head_dim, B, H, S, head_dim)
+        self.run_test(_rel_bias, dtype, B, H, S, head_dim, B, H, S, head_dim)
 
     @supported_platform
     def test_GQA_causal_mask(self):
@@ -3715,6 +3715,31 @@ class GraphModule(torch.nn.Module):
         attn_output = mod(q, k, v, mask)
         self.assertEqual(attn_output.device, torch.device("cuda:1"))
 
+    @supported_platform
+    def test_validate_small_embedding_size_error_message(self):
+        # eager support for small embedding size
+        q, k, v = [torch.randn(2, 2, 128, 8, device="cuda") for _ in range(3)]
+        flex_attention(q, k, v)
+
+        # compiled cpu support for small embedding size
+        q, k, v = [torch.randn(2, 2, 128, 8, device="cpu") for _ in range(3)]
+        flex_attention(q, k, v)
+
+        # compiled gpu kernel does not support small embedding size
+        q, k, v = [torch.randn(2, 2, 128, 8, device="cuda") for _ in range(3)]
+        compiled_fa = torch.compile(flex_attention)
+
+        with self.assertRaisesRegex(
+            torch._inductor.exc.InductorError,
+            "NYI: embedding dimension of the query, key, and value must be "
+            "at least 16 but got E=8 and Ev=8",
+        ):
+            compiled_fa(q, k, v)
+
+        # compiled gpu kernel supports large embedding size
+        q, k, v = [torch.randn(2, 2, 128, 16, device="cuda") for _ in range(3)]
+        compiled_fa = torch.compile(flex_attention)
+
 
 class TestBlockMask(InductorTestCase):
     @supported_platform
@@ -1070,7 +1070,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.run_test_with_paged_attention(score_mod_scale, dtype)
 
     @supported_platform
-    @common_utils.parametrize("head_dim", [13, 24, 94, 121])
+    @common_utils.parametrize("head_dim", [17, 24, 94, 121])
     @common_utils.parametrize("dtype", test_dtypes_fast)
     def test_non_pow_2_headdim(self, dtype, head_dim):
         self.run_test(_rel_bias, dtype, B, Hq, S, head_dim, B, Hkv, S, head_dim)
@@ -1255,6 +1255,15 @@ def flex_attention(
     )
 
+    # below is cuda path if device is not cpu
+    # tl.dot does not support embedding size less than 16
+    small_dqk = V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-1], 16))
+    small_dv = V.graph.sizevars.evaluate_expr(sympy.Lt(value.get_size()[-1], 16))
+    if small_dqk or small_dv:
+        raise NotImplementedError(
+            f"NYI: embedding dimension of the query, key, and value must be "
+            f"at least 16 but got E={query.get_size()[-1]} and Ev={value.get_size()[-1]}"
+        )
 
     (
         _,  # q_length
         _,  # kv_length