[MPS] Fix SDPA fp16 overflow (#165961)

Do not cast the intermediate result back to the lower-precision data type until
the softmax is finished; otherwise it can produce NaN.

Adjust the test to use 256.0 as the filler value rather than 64.0: with head_dim 80, a filler of 64.0 yields scaled scores of roughly 80 * 64^2 / sqrt(80) ≈ 3.7e4, which still fits in fp16, while 256.0 yields ≈ 5.9e5, which overflows fp16 and reproduces the bug.

Fixes https://github.com/pytorch/pytorch/issues/160841
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165961
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #165960
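
For context, a minimal sketch of the failure mode in plain PyTorch (not the MPS kernel itself; 586164 below is just the approximate scaled q.k score produced by the test filler): once a score saturates to inf in fp16, the max-subtraction inside softmax computes inf - inf = NaN.

    import torch

    # fp16 saturates above 65504, so a large fp32 score cast down becomes inf
    score = torch.tensor(586164.0)    # ~ the scaled q.k entry from the test
    print(score.half())               # tensor(inf, dtype=torch.float16)

    # softmax over a row of infs computes exp(inf - inf) = exp(nan) -> NaN
    row = torch.full((4,), float("inf"))
    print(torch.softmax(row, dim=0))  # tensor([nan, nan, nan, nan])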


@@ -92,13 +92,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
     }
     // upcasting to float32 if needed to improve precision when multiplying by the scale factor
-    if ([maskedMM dataType] != MPSDataTypeFloat32) {
-      maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
-    }
+    maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
     maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
-    if ([maskedMM dataType] != qTensor.dataType) {
-      maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil];
-    }
     if (is_causal) {
       auto causalMask = [mpsGraph constantWithScalar:1.0f
@@ -112,7 +107,9 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
                                          name:nil];
     } else if (attn_mask) {
       graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
-      maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
+      maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
+                                     secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
+                                                name:nil];
     }
     // Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
@@ -133,8 +130,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
       graph->qTensor = qTensor;
       graph->kTensor = kTensor;
       graph->vTensor = vTensor;
-      graph->outputTensor = output;
-      graph->attnTensor = sm;
+      graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
+      graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
     });
     auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
     auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);
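
Taken together, these hunks keep maskedMM in float32 from the scale multiplication through the softmax, and only cast back to the query dtype at the graph outputs (outputTensor and attnTensor), so the intermediate cast can no longer clip pre-softmax scores to fp16 range.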


@@ -2067,7 +2067,7 @@ class TestSDPA(NNTestCase):
     def test_scaled_dot_product_attention_fp16_overflow(self, device):
         # Regression test for https://github.com/pytorch/pytorch/issues/160841
-        x = torch.full((1, 32, 23, 80), 64.0, dtype=torch.half, device=device)
+        x = torch.full((1, 32, 23, 80), 256.0, dtype=torch.half, device=device)
         y = torch.nn.functional.scaled_dot_product_attention(x, x, x)
         self.assertFalse(y.isnan().any().item())
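
A quick sanity check on the new filler value (a sketch of the arithmetic; head_dim is 80 in the test shape, and 65504 is the fp16 maximum):

    fp16_max = 65504.0
    for filler in (64.0, 256.0):
        # each q.k entry is head_dim * filler^2, then multiplied by 1/sqrt(head_dim)
        score = 80 * filler**2 / 80**0.5
        print(filler, score, score > fp16_max)
    # 64.0  -> ~3.7e4, fits in fp16, so the old filler never tripped the bug
    # 256.0 -> ~5.9e5, overflows to inf, yielding NaN on the buggy path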