diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 20be0d6fe01..e06afddd05a 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -457,9 +457,24 @@ void gemm( return; } #endif + // for the fallback path, first compute gemm with beta = 0, + // and then add c in full precision. + int64_t c_size = n * m; + std::vector float_c(c_size, 0.f); gemm_no_downcast_stub( at::kCPU, at::kBFloat16, - transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m); + for (const auto j : c10::irange(n)) { + for (const auto i : c10::irange(m)) { + auto offset = j * ldc + i; + // beta == 0 won't propagate NaN from C + if (beta == 0.f) { + c[offset] = float_c[j * m + i]; + } else { + c[offset] = beta * c[offset] + float_c[j * m + i]; + } + } + } } void gemm( @@ -478,9 +493,24 @@ void gemm( return; } #endif + // for the fallback path, first compute gemm with beta = 0, + // and then add c in full precision. + int64_t c_size = n * m; + std::vector float_c(c_size, 0.f); gemm_no_downcast_stub( at::kCPU, at::kHalf, - transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m); + for (const auto j : c10::irange(n)) { + for (const auto i : c10::irange(m)) { + auto offset = j * ldc + i; + // beta == 0 won't propagate NaN from C + if (beta == 0.f) { + c[offset] = float_c[j * m + i]; + } else { + c[offset] = beta * c[offset] + float_c[j * m + i]; + } + } + } } void gemm(