[PyTorch] SDPA decomp: actually use attn_mask (#117579)

Summary: The scaled_dot_product_flash_attention_for_cpu decomposition dropped attn_mask instead of forwarding it to aten._scaled_dot_product_attention_math. Pass it along, and extend the decomp test to cover a non-None mask.
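For context, a minimal standalone sketch (not part of this PR, arbitrary shapes and values) of why dropping the mask matters: with a non-trivial attn_mask, eager SDPA produces a different result than the unmasked call, so a decomposition that silently replaces attn_mask with None diverges from eager.

```python
import torch
import torch.nn.functional as F

# Illustrative only: shapes and values are arbitrary, not taken from the PR's test.
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
mask = torch.randn(1, 1, 4, 4)  # additive float mask, broadcast over heads

masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
unmasked = F.scaled_dot_product_attention(q, k, v, None, dropout_p=0.0)
assert not torch.allclose(masked, unmasked)  # dropping the mask changes the output
```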

Test Plan:
```
cd ~/fbsource/fbcode/executorch/backends/xnnpack/test
buck test fbcode//mode/dev-nosan :test_xnnpack_ops -- test_fp32_sdpa
buck run fbcode//mode/dev-nosan :test_xnnpack_models -- executorch.backends.xnnpack.test.models.llama2_et_example.TestLlama2ETExample.test_fp32
```

Reviewed By: larryliu0820

Differential Revision: D52812369

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117579
Approved by: https://github.com/larryliu0820
Author: Digant Desai
Date: 2024-01-17 10:26:43 +00:00
Committed by: PyTorch MergeBot
Parent: 1deb75b584
Commit: e2830e6328
2 changed files with 27 additions and 17 deletions


```diff
@@ -927,9 +927,9 @@ class DecompOneOffTests(TestCase):
             def __init__(self):
                 super().__init__()
 
-            def forward(self, query_layer, key_layer, value_layer):
+            def forward(self, query_layer, key_layer, value_layer, mask=None, is_causal=True):
                 attn_output = F.scaled_dot_product_attention(
-                    query_layer, key_layer, value_layer, None, dropout_p=0.0, is_causal=True
+                    query_layer, key_layer, value_layer, attn_mask=mask, dropout_p=0.0, is_causal=is_causal
                 )
                 return attn_output
@@ -937,22 +937,25 @@ class DecompOneOffTests(TestCase):
         query_layer = torch.randn(1, 128, 100, 64, device=device)
         key_layer = torch.randn(1, 128, 100, 64, device=device)
         value_layer = torch.randn(1, 128, 100, 64, device=device)
+        masks = [None, torch.randn(1, 1, 100, 100, device=device)]
 
-        attention = ScaledDotProductAttention()
-        fx_g = make_fx(
-            attention,
-            decomposition_table=get_decompositions(
-                [
-                    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
-                ]
-            ),
-        )(query_layer, key_layer, value_layer)
+        for mask in masks:
+            is_causal = mask is None
+            attention = ScaledDotProductAttention()
+            fx_g = make_fx(
+                attention,
+                decomposition_table=get_decompositions(
+                    [
+                        torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
+                    ]
+                ),
+            )(query_layer, key_layer, value_layer, mask, is_causal)
 
-        compiled_res = fx_g(query_layer, key_layer, value_layer)
-        eager_res = F.scaled_dot_product_attention(
-            query_layer, key_layer, value_layer, None, dropout_p=0.0, is_causal=True
-        )
-        self.assertTrue(torch.allclose(compiled_res, eager_res, atol=1e-6, rtol=1e-5))
+            compiled_res = fx_g(query_layer, key_layer, value_layer, mask, is_causal)
+            eager_res = F.scaled_dot_product_attention(
+                query_layer, key_layer, value_layer, attn_mask=mask, dropout_p=0.0, is_causal=is_causal
+            )
+            self.assertTrue(torch.allclose(compiled_res, eager_res, atol=1e-6, rtol=1e-5))
 
 instantiate_device_type_tests(DecompOneOffTests, globals())
```
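Beyond the allclose comparison, one could also inspect the traced graph to confirm the CPU flash-attention op was actually decomposed. This is an illustrative sketch using the fx_g produced by make_fx above, not part of the test:

```python
# Hypothetical follow-up check; assumes fx_g is the GraphModule traced in the test.
call_targets = {
    node.target for node in fx_g.graph.nodes if node.op == "call_function"
}
# After decomposition, the CPU flash-attention op should no longer appear in the graph.
assert (
    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default
    not in call_targets
)
```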


```diff
@@ -4333,7 +4333,14 @@ def scaled_dot_product_flash_attention_for_cpu(
     )
     output, attn = aten._scaled_dot_product_attention_math.default(
-        query, key, value, None, dropout_p, is_causal, None, scale=scale
+        query,
+        key,
+        value,
+        attn_mask=attn_mask,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        dropout_mask=None,
+        scale=scale,
     )
     # Why this change?
     # In pre-dispatch export scaled_dot_product_attention is executed via
```
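For reference, a hedged sketch of the math-style attention that the mask is now forwarded into. This approximates what aten._scaled_dot_product_attention_math computes for a float mask, ignoring dropout and the boolean-mask/causal handling; it is not the actual decomposition code.

```python
import math
import torch

def sdpa_math_reference(query, key, value, attn_mask=None, scale=None):
    # (batch, heads, seq_len, head_dim) layout, matching the test above.
    scale = 1.0 / math.sqrt(query.size(-1)) if scale is None else scale
    scores = query @ key.transpose(-2, -1) * scale
    if attn_mask is not None:
        # The bug fixed here: the decomposition previously passed None, so this
        # additive term was silently skipped.
        scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ value
```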