diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index b16c1ef04fa..e06afddd05a 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -496,18 +496,18 @@ void gemm( // for the fallback path, first compute gemm with beta = 0, // and then add c in full precision. int64_t c_size = n * m; - std::vector float16_c(c_size, 0.f); - gemm_stub( + std::vector float_c(c_size, 0.f); + gemm_no_downcast_stub( at::kCPU, at::kHalf, - transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float16_c.data(), m); + transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m); for (const auto j : c10::irange(n)) { for (const auto i : c10::irange(m)) { auto offset = j * ldc + i; // beta == 0 won't propagate NaN from C if (beta == 0.f) { - c[offset] = c10::convert(float16_c[j * m + i]); + c[offset] = float_c[j * m + i]; } else { - c[offset] = beta * c[offset] + c10::convert(float16_c[j * m + i]); + c[offset] = beta * c[offset] + float_c[j * m + i]; } } }