[float16]: Fix the accumulation type for dot and gemv (#152676)

Fixes #147860

Also partially addresses https://github.com/pytorch/pytorch/issues/125438

Use float32 for accumulation with float16 and bfloat16 types.
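For reference, the numerical issue is that a float16 accumulator loses the small per-element contributions once the running sum grows: each addend falls below half an fp16 ulp and is rounded away. The sketch below is an illustration only, not the ATen kernel; it assumes compiling against PyTorch's c10/ATen headers, and contrasts a half-precision accumulator with the float accumulator that at::opmath_type<scalar_t> selects for Half and BFloat16.

// Minimal sketch (not the ATen kernel itself); assumes PyTorch's c10/ATen
// headers are on the include path.
#include <ATen/OpMathType.h>
#include <c10/util/Half.h>

#include <cstdint>
#include <iostream>
#include <vector>

// Naive dot product with an explicit accumulator type. Each operand is cast
// to acc_t before the multiply, so both the product and the running sum are
// computed in acc_t.
template <typename scalar_t, typename acc_t>
float dot_with_acc(int64_t n, const scalar_t* x, const scalar_t* y) {
  acc_t sum = acc_t(0);
  for (int64_t i = 0; i < n; ++i) {
    sum += static_cast<acc_t>(x[i]) * static_cast<acc_t>(y[i]);
  }
  return static_cast<float>(sum);
}

int main() {
  constexpr int64_t n = 1 << 14;  // 16384 terms of 0.01 * 1.0
  std::vector<c10::Half> x(n, c10::Half(0.01f));
  std::vector<c10::Half> y(n, c10::Half(1.0f));

  // Half accumulator: once the sum is large, each 0.01 contribution falls
  // below half an fp16 ulp and is rounded away, so the sum stalls early.
  float in_half = dot_with_acc<c10::Half, c10::Half>(n, x.data(), y.data());

  // Float accumulator, i.e. at::opmath_type<c10::Half> -- the type the
  // fixed gemv/dot kernels accumulate in.
  float in_float =
      dot_with_acc<c10::Half, at::opmath_type<c10::Half>>(n, x.data(), y.data());

  std::cout << "half accumulator:  " << in_half << "\n"
            << "float accumulator: " << in_float << "\n";
  return 0;
}

With the float accumulator the result matches the exact sum (~164); with the half accumulator it stalls far below that.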

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152676
Approved by: https://github.com/malfet
Krishna Bindumadhavan 2025-05-06 18:10:03 +00:00 committed by PyTorch MergeBot
parent 7a0781eaad
commit 08f5371571
2 changed files with 2 additions and 4 deletions

@@ -561,7 +561,7 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, i
       opmath_t sum = 0;
       const scalar_t *row_ = a + lda * i;
       for (const auto j : c10::irange(m)) {
-        sum += x[j * incx] * row_[j];
+        sum += static_cast<opmath_t>(x[j * incx]) * static_cast<opmath_t>(row_[j]);
       }
       if (beta == scalar_t(0)) {
         y[i * incy] = alpha * sum;
@@ -692,7 +692,7 @@ scalar_t dot_impl(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y,
     incx = 1;
     incy = 1;
   }
-  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{});
+  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<at::opmath_type<scalar_t>>{});
 }
 template <>
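The functor's template argument fixes the parameter and result type of std::multiplies::operator(), so switching it from scalar_t to at::opmath_type<scalar_t> (float for half/bfloat16) keeps the per-element product in float instead of rounding it back to the input type before it reaches the accumulator. The snippet below is a standalone illustration of that difference, assuming the c10::Half header; it is not the dot_naive code itself.

// Illustration only (not the dot_naive implementation); assumes PyTorch's
// c10::Half header is available.
#include <c10/util/Half.h>
#include <functional>

// std::multiplies<float>::operator() takes (const float&, const float&), so
// Half arguments are promoted to float and the product stays in float.
float product_in_float(c10::Half a, c10::Half b) {
  return std::multiplies<float>{}(a, b);
}

// std::multiplies<c10::Half>::operator() multiplies two Half values, and the
// product is rounded back to half before it reaches the accumulator.
c10::Half product_in_half(c10::Half a, c10::Half b) {
  return std::multiplies<c10::Half>{}(a, b);
}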

@@ -7003,8 +7003,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
                    torch.half))
     @dtypes(torch.bfloat16, torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
     def test_addmv(self, device, dtype):
-        if IS_ARM64 and device == 'cpu' and dtype == torch.float16:
-            raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
         # have to use torch.randn(...).to(bfloat16) instead of
         # torch.randn(..., dtype=bfloat16). randn does not support
         # bfloat16 yet.