[float16]: Fix the accumulation type for dot and gemv (#152676)
Fixes #147860. Also partially addresses https://github.com/pytorch/pytorch/issues/125438.

Use float32 for accumulation with float16 and bfloat16 types.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152676
Approved by: https://github.com/malfet
This commit is contained in:
parent 7a0781eaad
commit 08f5371571
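The change leans on at::opmath_type, which picks a wider type for arithmetic on reduced-precision elements. Below is a simplified, self-contained sketch of that mapping; the Half, BFloat16, and OpMathType names here are stand-ins for illustration, not ATen's actual definitions.

// Simplified sketch of an opmath-style type mapping, not ATen's definition:
// reduced-precision element types are promoted to float for arithmetic.
#include <type_traits>

struct Half     { /* 16-bit storage elided in this sketch */ };
struct BFloat16 { /* 16-bit storage elided in this sketch */ };

template <typename T> struct OpMathType            { using type = T;     };  // default: math in the element type
template <>           struct OpMathType<Half>      { using type = float; };  // fp16 math is carried out in fp32
template <>           struct OpMathType<BFloat16>  { using type = float; };  // bf16 math is carried out in fp32

template <typename T>
using opmath_type = typename OpMathType<T>::type;

static_assert(std::is_same_v<opmath_type<double>, double>, "wide types keep their own math type");
static_assert(std::is_same_v<opmath_type<Half>, float>, "half-precision math is widened to float");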
@@ -561,7 +561,7 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, i
       opmath_t sum = 0;
       const scalar_t *row_ = a + lda * i;
       for (const auto j : c10::irange(m)) {
-        sum += x[j * incx] * row_[j];
+        sum += static_cast<opmath_t>(x[j * incx]) * static_cast<opmath_t>(row_[j]);
       }
       if (beta == scalar_t(0)) {
         y[i * incy] = alpha * sum;
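In the gemv hunk above, the accumulator was already opmath_t; what changes is that both operands are now widened before the multiply, so the product itself is formed in the wider type. A minimal sketch of that pattern, assuming a half-like element type convertible to float and omitting the beta != 0 path the real kernel handles:

// Minimal sketch of the widened gemv inner loop; not ATen's gemv, just the pattern.
#include <cstdint>

// `half_t` stands in for a reduced-precision type such as at::Half here.
template <typename half_t>
void gemv_naive_transposed(int64_t m, int64_t n, float alpha,
                           const half_t* a, int64_t lda,
                           const half_t* x, int64_t incx,
                           half_t* y, int64_t incy) {
  for (int64_t i = 0; i < n; ++i) {
    float sum = 0.0f;  // accumulate in the wider "opmath" type
    const half_t* row = a + lda * i;
    for (int64_t j = 0; j < m; ++j) {
      // Cast both operands up front so the product is computed in float,
      // mirroring the static_cast<opmath_t>(...) added in the diff.
      sum += static_cast<float>(x[j * incx]) * static_cast<float>(row[j]);
    }
    // Round back to the storage type only once, at the final store
    // (the beta != 0 branch of the real kernel is omitted in this sketch).
    y[i * incy] = static_cast<half_t>(alpha * sum);
  }
}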
@@ -692,7 +692,7 @@ scalar_t dot_impl(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y,
     incx = 1;
     incy = 1;
   }
-  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<scalar_t>{});
+  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<at::opmath_type<scalar_t>>{});
 }
 
 template <>
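The dot change works through the functor instead: std::multiplies<at::opmath_type<scalar_t>> converts both of its arguments to the wider type before multiplying, so each product is formed in float for float16/bfloat16 inputs rather than being rounded to the narrow type first. The helper below is a sketch of a dot_naive-style reduction whose math type follows the functor; its signature and accumulation strategy are assumptions for illustration, not the actual blas_impl::dot_naive.

// Sketch of a naive dot reduction that lets the functor's result type drive
// the accumulation; not the ATen signature.
#include <cstdint>
#include <functional>

template <typename scalar_t, typename Functor>
auto dot_naive_sketch(int64_t n, const scalar_t* x, int64_t incx,
                      const scalar_t* y, int64_t incy, Functor op)
    -> decltype(op(x[0], y[0])) {
  using acc_t = decltype(op(x[0], y[0]));  // float when op is std::multiplies<float>
  acc_t sum = 0;
  for (int64_t i = 0; i < n; ++i) {
    // With std::multiplies<float>, both loads are converted to float before the
    // multiply, so neither the product nor this running sum is rounded down.
    sum += op(x[i * incx], y[i * incy]);
  }
  return sum;
}

// Usage sketch: dot_naive_sketch(n, x, 1, y, 1, std::multiplies<float>{}) keeps
// the whole reduction in float even when x and y point to half-precision data.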
@@ -7003,8 +7003,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
                       torch.half))
     @dtypes(torch.bfloat16, torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
     def test_addmv(self, device, dtype):
-        if IS_ARM64 and device == 'cpu' and dtype == torch.float16:
-            raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
         # have to use torch.randn(...).to(bfloat16) instead of
         # torch.randn(..., dtype=bfloat16). randn does not support
         # bfloat16 yet.