From 5211f4c1088f564cb15146e41bc592b7cf1824af Mon Sep 17 00:00:00 2001
From: Nikita Shulga
Date: Tue, 21 Oct 2025 16:04:47 -0700
Subject: [PATCH] [MPS] Fix SDPA fp16 overflow (#165961)

Do not cast the intermediate result back to the lower-precision data type
until softmax is finished, otherwise it might produce NaN.

Adjust the test to use 256 as the filler value rather than 64.

Fixes https://github.com/pytorch/pytorch/issues/160841

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165961
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #165960
---
 aten/src/ATen/native/mps/operations/Attention.mm | 15 ++++++---------
 test/test_transformers.py                        |  2 +-
 2 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/aten/src/ATen/native/mps/operations/Attention.mm b/aten/src/ATen/native/mps/operations/Attention.mm
index 11498ade6fd..ce571741778 100644
--- a/aten/src/ATen/native/mps/operations/Attention.mm
+++ b/aten/src/ATen/native/mps/operations/Attention.mm
@@ -92,13 +92,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
       }
 
       // upcasting to float32 if needed to improve precision when multiplying by the scale factor
-      if ([maskedMM dataType] != MPSDataTypeFloat32) {
-        maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
-      }
+      maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
       maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
-      if ([maskedMM dataType] != qTensor.dataType) {
-        maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil];
-      }
 
       if (is_causal) {
         auto causalMask = [mpsGraph constantWithScalar:1.0f
@@ -112,7 +107,9 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
                                                 name:nil];
       } else if (attn_mask) {
         graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
-        maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
+        maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
+                                       secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
+                                                  name:nil];
       }
 
       // Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
@@ -133,8 +130,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
       graph->qTensor = qTensor;
       graph->kTensor = kTensor;
       graph->vTensor = vTensor;
-      graph->outputTensor = output;
-      graph->attnTensor = sm;
+      graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
+      graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
     });
   auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
   auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 83bbba37c06..6fd66feece8 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -2067,7 +2067,7 @@ class TestSDPA(NNTestCase):
 
     def test_scaled_dot_product_attention_fp16_overflow(self, device):
         # Regression test for https://github.com/pytorch/pytorch/issues/160841
-        x = torch.full((1, 32, 23, 80), 64.0, dtype=torch.half, device=device)
+        x = torch.full((1, 32, 23, 80), 256.0, dtype=torch.half, device=device)
         y = torch.nn.functional.scaled_dot_product_attention(x, x, x)
         self.assertFalse(y.isnan().any().item())
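
Note (not part of the patch): a back-of-the-envelope sketch of why the test's
filler value was bumped from 64 to 256. With q == k filled with a constant c
and head_dim = 80, every attention logit is c*c*80/sqrt(80). For c = 64 the
scaled logit (~36,636) still fits in fp16, so casting back to half before
softmax went undetected; for c = 256 it (~586,171) exceeds the fp16 maximum of
65504, overflows to inf, and softmax's max-subtraction (inf - inf) turns the
whole row into NaN. The snippet below reproduces the arithmetic on plain CPU
PyTorch, no MPS required; the shapes mirror the test (seq_len = 23,
head_dim = 80):

    import math
    import torch

    head_dim, seq_len = 80, 23
    scale = 1.0 / math.sqrt(head_dim)
    fp16_max = torch.finfo(torch.half).max  # 65504.0

    for c in (64.0, 256.0):
        logit = c * c * head_dim * scale  # one entry of the scaled QK^T row
        # Old behavior: scaled logits cast back to fp16 before softmax.
        row16 = torch.full((seq_len,), logit).half()  # overflows to inf for c=256
        nan = torch.softmax(row16.float(), dim=-1).isnan().any().item()
        print(f"c={c:5.0f}  logit={logit:10.0f}  "
              f"exceeds_fp16_max={logit > fp16_max}  softmax_nan={nan}")
        # Fixed behavior: keep fp32 through softmax; the result stays finite.
        row32 = torch.full((seq_len,), logit)
        assert not torch.softmax(row32, dim=-1).isnan().any()

This mirrors what the patch does on the MPS side: the graph now stays in
float32 through the mask addition and softmax, and only the final
outputTensor/attnTensor are cast back to qTensor.dataType via castMPSTensor.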