From cfbd99fdfd7282c8969f123d5819a47d408ce78a Mon Sep 17 00:00:00 2001 From: Cyrus Daruwala Date: Tue, 27 May 2025 17:43:21 +0000 Subject: [PATCH] [Pytorch] Add option to CPU Blas GEMM to avoid output downcast (#154012) Summary: Dot product for a single output element consists of 3 steps (both input vectors have elements of type scalar_t): 1. elementwise vector multiply (scalar_t x scalar_t -> opmath_t) 2. vector reduction to a scalar value (opmath_t -> opmath_t) 3. optional downcast if opmath_t != out_t The current blas kernel performs steps 1 and 2 correctly, but for step 3, it will always downcast to scalar_t even when opmath_t == output_t (and then do an upcast back to output_t), which results in precision loss. This diff fixes the precision loss in the BlasKernel Test Plan: Attention CI passes Differential Revision: D75023858 topic: not user facing Pull Request resolved: https://github.com/pytorch/pytorch/pull/154012 Approved by: https://github.com/Valentine233, https://github.com/aditew01, https://github.com/CaoE, https://github.com/drisspg --- aten/src/ATen/native/CPUBlas.cpp | 11 +++--- aten/src/ATen/native/CPUBlas.h | 12 +++++++ aten/src/ATen/native/cpu/BlasKernel.cpp | 47 ++++++++++++++++++------- 3 files changed, 53 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 249ba065f2e..3b18832b052 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -135,6 +135,7 @@ CBLAS_TRANSPOSE to_apple_accelerate_transpose(TransposeType trans) { } // namespace (anonymous) DEFINE_DISPATCH(gemm_stub); +DEFINE_DISPATCH(gemm_no_downcast_stub); void gemm( TransposeType transa, TransposeType transb, @@ -452,18 +453,18 @@ void gemm( // for the fallback path, first compute gemm with beta = 0, // and then add c in full precision. int64_t c_size = n * m; - std::vector bfloat_c(c_size, 0.f); - gemm_stub( + std::vector float_c(c_size, 0.f); + gemm_no_downcast_stub( at::kCPU, at::kBFloat16, - transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, bfloat_c.data(), m); + transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float_c.data(), m); for (const auto j : c10::irange(n)) { for (const auto i : c10::irange(m)) { auto offset = j * ldc + i; // beta == 0 won't propagate NaN from C if (beta == 0.f) { - c[offset] = c10::convert(bfloat_c[j * m + i]); + c[offset] = float_c[j * m + i]; } else { - c[offset] = beta * c[offset] + c10::convert(bfloat_c[j * m + i]); + c[offset] = beta * c[offset] + float_c[j * m + i]; } } } diff --git a/aten/src/ATen/native/CPUBlas.h b/aten/src/ATen/native/CPUBlas.h index f5a8b9018dc..95d11903dc7 100644 --- a/aten/src/ATen/native/CPUBlas.h +++ b/aten/src/ATen/native/CPUBlas.h @@ -29,6 +29,18 @@ using gemm_fn = void(*)( DECLARE_DISPATCH(gemm_fn, gemm_stub) +using gemm_no_downcast_fn = void(*)( + at::ScalarType type, + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const Scalar& alpha, + const void *a, int64_t lda, + const void *b, int64_t ldb, + const Scalar& beta, + void *c, int64_t ldc); + +DECLARE_DISPATCH(gemm_no_downcast_fn, gemm_no_downcast_stub) + template void gemm( TransposeType transa, TransposeType transb, diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index 9d2509f1f24..82e7dfd213f 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -99,7 +99,7 @@ auto sum(int64_t N, Func f) { return partial_sums[0]; } -template +template std::enable_if_t, void> gemm_notrans_( int64_t m, @@ -111,7 +111,7 @@ gemm_notrans_( const scalar_t* b, int64_t ldb, opmath_t beta, - scalar_t* c, + out_t* c, int64_t ldc) { // c *= beta scale_(m, n, beta, c, ldc); @@ -135,7 +135,7 @@ gemm_notrans_( } // std::is_same || std::is_same -template +template std::enable_if_t, void> gemm_notrans_( int64_t m, @@ -147,7 +147,7 @@ gemm_notrans_( const scalar_t* b, int64_t ldb, opmath_t beta, - scalar_t* c, + out_t* c, int64_t ldc) { // c += alpha * (a @ b) for (const auto i : c10::irange(m)) { @@ -165,7 +165,7 @@ gemm_notrans_( } } -template +template void gemm_transa_( TransposeType transa, int64_t m, int64_t n, int64_t k, @@ -173,7 +173,7 @@ void gemm_transa_( const scalar_t *a, int64_t lda, const scalar_t *b, int64_t ldb, opmath_t beta, - scalar_t *c, int64_t ldc) { + out_t *c, int64_t ldc) { // c = alpha * (a.T @ b) + beta * c const scalar_t *a_ = a; for (const auto i : c10::irange(m)) { @@ -225,6 +225,7 @@ void gemm_transb_impl( } } +// in this case, scalar_t == opmath_t == out_t so out_t template param is not needed template std::enable_if_t, void> gemm_transb_( @@ -247,7 +248,7 @@ gemm_transb_( } // std::is_same || std::is_same -template +template std::enable_if_t, void> gemm_transb_( TransposeType transb, @@ -260,7 +261,7 @@ gemm_transb_( const scalar_t* b, int64_t ldb, opmath_t beta, - scalar_t* c, + out_t* c, int64_t ldc) { // We need to calculate full-precision dot products for correctness; // users notice error accumulation with reduced-width types (e.g., @@ -304,7 +305,7 @@ gemm_transb_( } } -template +template void gemm_transab_( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, @@ -312,7 +313,7 @@ void gemm_transab_( const scalar_t *a, int64_t lda, const scalar_t *b, int64_t ldb, opmath_t beta, - scalar_t *c, int64_t ldc) { + out_t *c, int64_t ldc) { // c = beta * c + alpha * (a.T @ b.T) for (const auto i : c10::irange(m)) { for (const auto j : c10::irange(n)) { @@ -436,7 +437,7 @@ void gemm_transa_( } #endif // !defined(C10_MOBILE) -template +template void gemm_core_( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, @@ -444,7 +445,7 @@ void gemm_core_( const scalar_t *a, int64_t lda, const scalar_t *b, int64_t ldb, opmath_t beta, - scalar_t *c, int64_t ldc) { + out_t *c, int64_t ldc) { if (transa == TransposeType::NoTranspose && transb == TransposeType::NoTranspose) { return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -493,6 +494,27 @@ void cpublas_gemm_impl( }); } +void cpublas_gemm_no_downcast_impl( + at::ScalarType type, + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const Scalar& alpha, + const void *a, int64_t lda, + const void *b, int64_t ldb, + const Scalar& beta, + void *c, int64_t ldc) { +_AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_no_downcast_impl", [&]{ + using opmath_t = at::opmath_type; + gemm_core_( + transa, transb, m, n, k, + alpha.to(), + static_cast(a), lda, + static_cast(b), ldb, + beta.to(), + static_cast(c), ldc); + }); +} + void cpublas_axpy_impl(at::ScalarType type, int64_t n, const Scalar& _a, const void *_x, int64_t incx, void *_y, int64_t incy){ if (type == at::kBool) { auto a = _a.to(); @@ -530,6 +552,7 @@ void cpublas_copy_impl(at::ScalarType type, int64_t n, const void *_x, int64_t i REGISTER_DISPATCH(cpublas::gemm_stub, &cpublas::cpublas_gemm_impl) +REGISTER_DISPATCH(cpublas::gemm_no_downcast_stub, &cpublas::cpublas_gemm_no_downcast_impl) REGISTER_DISPATCH(cpublas::axpy_stub, &cpublas::cpublas_axpy_impl) REGISTER_DISPATCH(cpublas::copy_stub, &cpublas::cpublas_copy_impl)