From afa6e5604d78b447aca3e30d9843732c1ee26885 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Sep 2025 19:56:48 +0000 Subject: [PATCH] Revert "[BE] Cleanup stale comments/copy from `gemm` (#162001)" This reverts commit b40d9432be44a6b5974ee62e7d19c3c61c5ece37. Reverted https://github.com/pytorch/pytorch/pull/162001 on behalf of https://github.com/jeanschmidt due to break a few internal tests ([comment](https://github.com/pytorch/pytorch/pull/161999#issuecomment-3255381925)) --- aten/src/ATen/native/CPUBlas.cpp | 34 ++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 20be0d6fe01..e06afddd05a 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -457,9 +457,24 @@ void gemm( return; } #endif + // for the fallback path, first compute gemm with beta = 0, + // and then add c in full precision. + int64_t c_size = n * m; + std::vector<float> float_c(c_size, 0.f); gemm_no_downcast_stub( at::kCPU, at::kBFloat16, - transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m); + for (const auto j : c10::irange(n)) { + for (const auto i : c10::irange(m)) { + auto offset = j * ldc + i; + // beta == 0 won't propagate NaN from C + if (beta == 0.f) { + c[offset] = float_c[j * m + i]; + } else { + c[offset] = beta * c[offset] + float_c[j * m + i]; + } + } + } } void gemm( @@ -478,9 +493,24 @@ void gemm( return; } #endif + // for the fallback path, first compute gemm with beta = 0, + // and then add c in full precision. 
+ int64_t c_size = n * m; + std::vector<float> float_c(c_size, 0.f); gemm_no_downcast_stub( at::kCPU, at::kHalf, - transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m); + for (const auto j : c10::irange(n)) { + for (const auto i : c10::irange(m)) { + auto offset = j * ldc + i; + // beta == 0 won't propagate NaN from C + if (beta == 0.f) { + c[offset] = float_c[j * m + i]; + } else { + c[offset] = beta * c[offset] + float_c[j * m + i]; + } + } + } } void gemm(