Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[PyTorch] SDPA decomp: actually use attn_mask (#117579)
Summary: Need to pass the attn_mask along in the SDPA decomposition (it was being dropped).

Test Plan:
```
cd ~/fbsource/fbcode/executorch/backends/xnnpack/test
buck test fbcode//mode/dev-nosan :test_xnnpack_ops -- test_fp32_sdpa
buck run fbcode//mode/dev-nosan :test_xnnpack_models -- executorch.backends.xnnpack.test.models.llama2_et_example.TestLlama2ETExample.test_fp32
```

Reviewed By: larryliu0820

Differential Revision: D52812369

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117579
Approved by: https://github.com/larryliu0820
This commit is contained in:
parent 1deb75b584
commit e2830e6328
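Editorial note (not part of the commit): the snippet below is a minimal sketch of why dropping the mask matters. With a non-trivial additive `attn_mask`, eager `F.scaled_dot_product_attention` gives a different answer than the mask-free call that the old decomposition effectively reproduced. The shapes are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn.functional as F

# Arbitrary illustrative shapes: (batch, heads, seq_len, head_dim) plus a
# float mask that broadcasts over the attention scores.
q = torch.randn(1, 4, 16, 32)
k = torch.randn(1, 4, 16, 32)
v = torch.randn(1, 4, 16, 32)
mask = torch.randn(1, 1, 16, 16)

with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
without_mask = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

# A decomposition that silently replaces attn_mask with None reproduces
# `without_mask`; the fixed decomposition should match `with_mask` instead.
assert not torch.allclose(with_mask, without_mask)
```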
```diff
@@ -927,9 +927,9 @@ class DecompOneOffTests(TestCase):
             def __init__(self):
                 super().__init__()

-            def forward(self, query_layer, key_layer, value_layer):
+            def forward(self, query_layer, key_layer, value_layer, mask=None, is_causal=True):
                 attn_output = F.scaled_dot_product_attention(
-                    query_layer, key_layer, value_layer, None, dropout_p=0.0, is_causal=True
+                    query_layer, key_layer, value_layer, attn_mask=mask, dropout_p=0.0, is_causal=is_causal
                 )
                 return attn_output

```
```diff
@@ -937,22 +937,25 @@ class DecompOneOffTests(TestCase):
         query_layer = torch.randn(1, 128, 100, 64, device=device)
         key_layer = torch.randn(1, 128, 100, 64, device=device)
         value_layer = torch.randn(1, 128, 100, 64, device=device)
+        masks = [None, torch.randn(1, 1, 100, 100, device=device)]

-        attention = ScaledDotProductAttention()
-        fx_g = make_fx(
-            attention,
-            decomposition_table=get_decompositions(
-                [
-                    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
-                ]
-            ),
-        )(query_layer, key_layer, value_layer)
+        for mask in masks:
+            is_causal = mask is None
+            attention = ScaledDotProductAttention()
+            fx_g = make_fx(
+                attention,
+                decomposition_table=get_decompositions(
+                    [
+                        torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
+                    ]
+                ),
+            )(query_layer, key_layer, value_layer, mask, is_causal)

-        compiled_res = fx_g(query_layer, key_layer, value_layer)
-        eager_res = F.scaled_dot_product_attention(
-            query_layer, key_layer, value_layer, None, dropout_p=0.0, is_causal=True
-        )
-        self.assertTrue(torch.allclose(compiled_res, eager_res, atol=1e-6, rtol=1e-5))
+            compiled_res = fx_g(query_layer, key_layer, value_layer, mask, is_causal)
+            eager_res = F.scaled_dot_product_attention(
+                query_layer, key_layer, value_layer, attn_mask=mask, dropout_p=0.0, is_causal=is_causal
+            )
+            self.assertTrue(torch.allclose(compiled_res, eager_res, atol=1e-6, rtol=1e-5))


 instantiate_device_type_tests(DecompOneOffTests, globals())
```
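A hedged complement to the test above: tracing the same kind of module through `make_fx` with the flash-attention-for-cpu decomposition and printing the generated code makes the fix visible, since the mask placeholder should now flow into the decomposed call instead of a hard-coded `None`. The module, tensor shapes, and the `print(fx_g.code)` inspection below are illustrative additions, not lines from the test file.

```python
import torch
import torch.nn.functional as F
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

class ScaledDotProductAttention(torch.nn.Module):
    def forward(self, q, k, v, mask=None, is_causal=True):
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=is_causal
        )

# Illustrative CPU tensors; shapes are arbitrary.
q = torch.randn(1, 8, 32, 64)
k = torch.randn(1, 8, 32, 64)
v = torch.randn(1, 8, 32, 64)
mask = torch.randn(1, 1, 32, 32)

fx_g = make_fx(
    ScaledDotProductAttention(),
    decomposition_table=get_decompositions(
        [torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default]
    ),
)(q, k, v, mask, False)

# After the fix, the traced graph should reference the mask argument in the
# decomposed attention call rather than dropping it.
print(fx_g.code)
```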
```diff
@@ -4333,7 +4333,14 @@ def scaled_dot_product_flash_attention_for_cpu(
     )

     output, attn = aten._scaled_dot_product_attention_math.default(
-        query, key, value, None, dropout_p, is_causal, None, scale=scale
+        query,
+        key,
+        value,
+        attn_mask=attn_mask,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        dropout_mask=None,
+        scale=scale,
     )
     # Why this change?
     # In pre-dispatch export scaled_dot_product_attention is executed via
```
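The changed call site now forwards `attn_mask` into `aten._scaled_dot_product_attention_math`. As a rough editorial sketch of what that math path computes for an additive float mask with no dropout (based on the documented SDPA semantics, not the ATen source), the logic is: build a bias from `is_causal` and/or `attn_mask`, then take `softmax(QK^T * scale + bias) @ V`. The helper name `sdpa_math_sketch` and the final sanity check are illustrative.

```python
import math
import torch

def sdpa_math_sketch(query, key, value, attn_mask=None, is_causal=False, scale=None):
    # Sketch only: build an additive bias from is_causal and/or attn_mask,
    # then compute softmax(Q @ K^T * scale + bias) @ V. Dropout is omitted.
    L, S = query.size(-2), key.size(-2)
    scale = scale if scale is not None else 1.0 / math.sqrt(query.size(-1))
    bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
        bias = bias.masked_fill(causal.logical_not(), float("-inf"))
    if attn_mask is not None:
        bias = bias + attn_mask  # the argument the old decomposition replaced with None
    attn = torch.softmax(query @ key.transpose(-2, -1) * scale + bias, dim=-1)
    return attn @ value, attn

# Sanity check of the sketch against the op the decomposition now calls,
# using fp32 inputs and an additive float mask.
q, k, v = (torch.randn(1, 2, 6, 8) for _ in range(3))
m = torch.randn(1, 1, 6, 6)
ref, _ = torch.ops.aten._scaled_dot_product_attention_math.default(
    q, k, v, attn_mask=m, dropout_p=0.0, is_causal=False, dropout_mask=None
)
out, _ = sdpa_math_sketch(q, k, v, attn_mask=m)
torch.testing.assert_close(out, ref, atol=1e-6, rtol=1e-5)
```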