[MPS] Fix SDPA fp16 overflow (#165961)
Do not cast the intermediate result back to the lower-precision data type until the softmax is finished, otherwise it might produce NaN.
Adjust the test to use 256 as the filler value rather than 64.

Fixes https://github.com/pytorch/pytorch/issues/160841
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165961
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #165960
This commit is contained in:
parent ad9027b80d
commit 5211f4c108
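Why the new filler value trips the bug: with q, k, and v all filled with a constant c and head_dim = 80, every attention logit equals c * c * 80 / sqrt(80) = c^2 * sqrt(80). The back-of-the-envelope check below only illustrates that arithmetic (the shapes and fill values come from the test in this commit; the script itself is not part of it):

import math
import torch

head_dim = 80
fp16_max = torch.finfo(torch.half).max  # 65504.0
for fill in (64.0, 256.0):
    # q @ k.T sums head_dim products of fill*fill; SDPA then scales by 1/sqrt(head_dim)
    logit = fill * fill * head_dim / math.sqrt(head_dim)
    print(f"fill={fill}: logit ~= {logit:.0f}, fits in fp16: {logit <= fp16_max}")

# fill=64.0:  logit ~= 36636  -> fits in fp16, so the old filler never overflowed
# fill=256.0: logit ~= 586172 -> overflows fp16 to inf; softmax over inf yields NaN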
@@ -92,13 +92,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
     }
-    // upcasting to float32 if needed to improve precision when multiplying by the scale factor
-    if ([maskedMM dataType] != MPSDataTypeFloat32) {
-      maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
-    }
+    maskedMM = castMPSTensor(mpsGraph, maskedMM, MPSDataTypeFloat32);
     maskedMM = [mpsGraph multiplicationWithPrimaryTensor:maskedMM secondaryTensor:scaleTensor name:nil];
-    if ([maskedMM dataType] != qTensor.dataType) {
-      maskedMM = [mpsGraph castTensor:maskedMM toType:qTensor.dataType name:nil];
-    }
 
     if (is_causal) {
       auto causalMask = [mpsGraph constantWithScalar:1.0f
@@ -112,7 +107,9 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
                                               name:nil];
     } else if (attn_mask) {
       graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
-      maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
+      maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM
+                                     secondaryTensor:castMPSTensor(mpsGraph, graph->maskTensor, maskedMM.dataType)
+                                                name:nil];
     }
 
     // Account for case where all values were masked causing division by 0 in softmax (issue:#156707)
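In the hunk above, maskedMM is now float32 by the time the (possibly fp16) attention mask is added, so the mask is cast to maskedMM.dataType first, presumably because MPSGraph binary ops do not promote mixed operand dtypes the way eager PyTorch does. A rough eager-mode analogue of what the explicit cast accomplishes (illustrative only, not part of the commit):

import torch

logits = torch.zeros(4)  # float32 intermediate, as after the upcast
mask = torch.tensor([0.0, float("-inf"), 0.0, 0.0], dtype=torch.half)
logits = logits + mask.float()  # explicit upcast mirrors castMPSTensor
print(logits.softmax(-1))       # masked slot gets probability 0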
@@ -133,8 +130,8 @@ static std::tuple<Tensor, Tensor> sdpa_general_mps(const Tensor& query,
     graph->qTensor = qTensor;
     graph->kTensor = kTensor;
     graph->vTensor = vTensor;
-    graph->outputTensor = output;
-    graph->attnTensor = sm;
+    graph->outputTensor = castMPSTensor(mpsGraph, output, qTensor.dataType);
+    graph->attnTensor = castMPSTensor(mpsGraph, sm, qTensor.dataType);
   });
   auto qPlaceholder = Placeholder(cachedGraph->qTensor, query);
   auto kPlaceholder = Placeholder(cachedGraph->kTensor, key);
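The hunk above is the core of the fix: the output and the attention weights stay in float32 all the way through the softmax and are cast back to the query dtype only at the very end. A small eager-mode illustration of why the ordering matters, reusing the logit magnitude computed earlier (illustrative, not part of the commit):

import torch

logits = torch.full((4,), 586172.0)  # float32 logits from the fill=256 repro
print(logits.half().softmax(-1))     # cast first: fp16 overflows to inf, softmax -> NaN
print(logits.softmax(-1).half())     # softmax in fp32, cast afterwards -> 0.25 everywhere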
@@ -2067,7 +2067,7 @@ class TestSDPA(NNTestCase):
 
     def test_scaled_dot_product_attention_fp16_overflow(self, device):
         # Regression test for https://github.com/pytorch/pytorch/issues/160841
-        x = torch.full((1, 32, 23, 80), 64.0, dtype=torch.half, device=device)
+        x = torch.full((1, 32, 23, 80), 256.0, dtype=torch.half, device=device)
         y = torch.nn.functional.scaled_dot_product_attention(x, x, x)
         self.assertFalse(y.isnan().any().item())
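To run just this regression test, assuming the TestSDPA class above lives in test/test_transformers.py (the diff view does not show the file name, so the path is an assumption):

python test/test_transformers.py -k test_scaled_dot_product_attention_fp16_overflow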