diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp
index e06afddd05a..20be0d6fe01 100644
--- a/aten/src/ATen/native/CPUBlas.cpp
+++ b/aten/src/ATen/native/CPUBlas.cpp
@@ -457,24 +457,9 @@ void gemm(
     return;
   }
 #endif
-  // for the fallback path, first compute gemm with beta = 0,
-  // and then add c in full precision.
-  int64_t c_size = n * m;
-  std::vector<float> float_c(c_size, 0.f);
   gemm_no_downcast_stub(
       at::kCPU, at::kBFloat16,
-      transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m);
-  for (const auto j : c10::irange(n)) {
-    for (const auto i : c10::irange(m)) {
-      auto offset = j * ldc + i;
-      // beta == 0 won't propagate NaN from C
-      if (beta == 0.f) {
-        c[offset] = float_c[j * m + i];
-      } else {
-        c[offset] = beta * c[offset] + float_c[j * m + i];
-      }
-    }
-  }
+      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
 }
 
 void gemm(
@@ -493,24 +478,9 @@ void gemm(
     return;
   }
 #endif
-  // for the fallback path, first compute gemm with beta = 0,
-  // and then add c in full precision.
-  int64_t c_size = n * m;
-  std::vector<float> float_c(c_size, 0.f);
   gemm_no_downcast_stub(
       at::kCPU, at::kHalf,
-      transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m);
-  for (const auto j : c10::irange(n)) {
-    for (const auto i : c10::irange(m)) {
-      auto offset = j * ldc + i;
-      // beta == 0 won't propagate NaN from C
-      if (beta == 0.f) {
-        c[offset] = float_c[j * m + i];
-      } else {
-        c[offset] = beta * c[offset] + float_c[j * m + i];
-      }
-    }
-  }
+      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
 }
 
 void gemm(
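
The deleted fallback handled beta by hand: it ran the product with beta = 0 into a temporary float buffer, then blended beta * c back in at full precision, skipping the read of c entirely when beta == 0 so NaNs in an uninitialized output would not be propagated. After this diff, the at::kBFloat16 and at::kHalf wrappers simply forward beta, c and ldc, relying on gemm_no_downcast_stub to provide those same semantics itself. As a sanity reference only, the sketch below spells out that contract for the plain column-major, no-transpose case; the function name and loop structure are illustrative assumptions, not the ATen kernel.

// Minimal sketch of the beta semantics the call sites now rely on:
// accumulate in float, apply alpha/beta in float, write back to the
// reduced-precision C, and never read C when beta == 0.
// Illustrative only; not the actual gemm_no_downcast_stub implementation.
#include <cstdint>

template <typename scalar_t>
void reference_gemm_no_downcast(
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const scalar_t* a, int64_t lda,   // m x k, column-major, no transpose
    const scalar_t* b, int64_t ldb,   // k x n, column-major, no transpose
    float beta,
    scalar_t* c, int64_t ldc) {       // m x n, column-major
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      float acc = 0.f;                // dot product accumulated in float
      for (int64_t l = 0; l < k; ++l) {
        acc += static_cast<float>(a[l * lda + i]) *
               static_cast<float>(b[j * ldb + l]);
      }
      // beta == 0 must not read C, so garbage/NaN in C is not propagated.
      float out = (beta == 0.f)
          ? alpha * acc
          : alpha * acc + beta * static_cast<float>(c[j * ldc + i]);
      c[j * ldc + i] = static_cast<scalar_t>(out);
    }
  }
}

With the stub guaranteeing this behavior, the per-dtype wrappers no longer need the temporary float_c buffer or the manual accumulation loops removed above.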