From 02c83f13348631d80aa23f57aaff6b7d1223bbdd Mon Sep 17 00:00:00 2001
From: Nikita Shulga
Date: Tue, 2 Sep 2025 14:06:36 -0700
Subject: [PATCH] [BLAS] Avoid downcasts for fp16fp16->fp32 BLAS (#161999)

Followup after https://github.com/pytorch/pytorch/pull/154012

Fixes CPU part of https://github.com/pytorch/pytorch/issues/160841

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161999
Approved by: https://github.com/drisspg
---
 aten/src/ATen/native/CPUBlas.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp
index b16c1ef04fa..e06afddd05a 100644
--- a/aten/src/ATen/native/CPUBlas.cpp
+++ b/aten/src/ATen/native/CPUBlas.cpp
@@ -496,18 +496,18 @@ void gemm(
   // for the fallback path, first compute gemm with beta = 0,
   // and then add c in full precision.
   int64_t c_size = n * m;
-  std::vector<at::Half> float16_c(c_size, 0.f);
-  gemm_stub(
+  std::vector<float> float_c(c_size, 0.f);
+  gemm_no_downcast_stub(
       at::kCPU, at::kHalf,
-      transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float16_c.data(), m);
+      transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m);
   for (const auto j : c10::irange(n)) {
     for (const auto i : c10::irange(m)) {
       auto offset = j * ldc + i;
       // beta == 0 won't propagate NaN from C
       if (beta == 0.f) {
-        c[offset] = c10::convert<float>(float16_c[j * m + i]);
+        c[offset] = float_c[j * m + i];
       } else {
-        c[offset] = beta * c[offset] + c10::convert<float>(float16_c[j * m + i]);
+        c[offset] = beta * c[offset] + float_c[j * m + i];
       }
     }
   }
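
Note: the following is an illustrative, self-contained sketch (not the ATen implementation) of the pattern the patch switches to: the fp16 x fp16 product is accumulated into a full-precision float buffer and only then blended with beta, instead of round-tripping the partial result through a half-precision buffer. The helper name gemm_f16f16_f32, the no-transpose/column-major layout, and the use of the _Float16 extension are assumptions made for the sketch.

#include <cstddef>
#include <cstdint>
#include <vector>

// Assumption: the compiler provides the _Float16 extension (GCC/Clang);
// it stands in for at::Half here.
using half_t = _Float16;

// Simplified fp16 x fp16 -> fp32 GEMM fallback, column-major, no transpose:
// C = alpha * A * B + beta * C, with all accumulation done in float.
void gemm_f16f16_f32(int64_t m, int64_t n, int64_t k,
                     float alpha, const half_t* a, int64_t lda,
                     const half_t* b, int64_t ldb,
                     float beta, float* c, int64_t ldc) {
  // Full-precision scratch buffer for alpha * A * B (leading dimension m),
  // mirroring float_c in the patched code: no downcast to half anywhere.
  std::vector<float> float_c(static_cast<size_t>(n) * m, 0.f);
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int64_t p = 0; p < k; ++p) {
        acc += static_cast<float>(a[p * lda + i]) *
               static_cast<float>(b[j * ldb + p]);
      }
      float_c[j * m + i] = alpha * acc;
    }
  }
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      auto offset = j * ldc + i;
      // beta == 0 must not propagate NaN from the existing C values.
      c[offset] = (beta == 0.f) ? float_c[j * m + i]
                                : beta * c[offset] + float_c[j * m + i];
    }
  }
}

The design point mirrors the patch: because the accumulation buffer stays in float and the output c is already float, no intermediate rounding to half precision ever occurs on this path.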