diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp
index 1b5fd85a83f..654a226a85e 100644
--- a/aten/src/ATen/native/BlasKernel.cpp
+++ b/aten/src/ATen/native/BlasKernel.cpp
@@ -90,6 +90,8 @@ namespace at::native {
 #if !defined(C10_MOBILE)
 DEFINE_DISPATCH(fp16_gemv_trans_stub);
 DEFINE_DISPATCH(bf16_gemv_trans_stub);
+DEFINE_DISPATCH(fp16_dot_stub);
+DEFINE_DISPATCH(bf16_dot_stub);
 #endif // !defined(C10_MOBILE)
 
 namespace blas_impl {
@@ -120,6 +122,15 @@ void fp16_gemv_trans(
   fp16_gemv_trans_stub(kCPU, m, n, alpha, a, lda, x, incx, beta, y, incy);
 }
 
+static float fp16_dot(
+    const int64_t n,
+    const Half* x,
+    const int64_t incx,
+    const Half* y,
+    const int64_t incy) {
+  return fp16_dot_stub(kCPU, n, x, incx, y, incy);
+}
+
 #endif // !defined(C10_MOBILE)
 
 #if defined(__aarch64__) && !defined(C10_MOBILE)
@@ -384,6 +395,16 @@ void gemv_fast_path(
       y,
       *incy);
 }
+
+static float bf16_dot(
+    const int64_t n,
+    const BFloat16* x,
+    const int64_t incx,
+    const BFloat16* y,
+    const int64_t incy) {
+  return bf16_dot_stub(kCPU, n, x, incx, y, incy);
+}
+
 #if !defined(__aarch64__)
 // Currently, only fp16_gemv_trans is built for non-aarch64.
 template <>
@@ -695,6 +716,34 @@ c10::complex dot_impl(int64_t n, const c10::complex* x, int64_t in
   return dot_impl_floating(n, x, incx, y, incy);
 }
 
+template <>
+Half dot_impl(int64_t n, const Half* x, int64_t incx, const Half* y, int64_t incy) {
+  if (n == 1) {
+    incx = 1;
+    incy = 1;
+  }
+#if !defined(C10_MOBILE)
+  if (incx == 1 && incy == 1) {
+    return blas_impl::fp16_dot(n, x, incx, y, incy);
+  }
+#endif // !defined(C10_MOBILE)
+  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<float>{});
+}
+
+template <>
+BFloat16 dot_impl(int64_t n, const BFloat16* x, int64_t incx, const BFloat16* y, int64_t incy) {
+  if (n == 1) {
+    incx = 1;
+    incy = 1;
+  }
+#if !defined(C10_MOBILE)
+  if (incx == 1 && incy == 1) {
+    return blas_impl::bf16_dot(n, x, incx, y, incy);
+  }
+#endif // !defined(C10_MOBILE)
+  return blas_impl::dot_naive(n, x, incx, y, incy, std::multiplies<float>{});
+}
+
 namespace {
 template <typename scalar_t>
 struct vdot_op {
@@ -721,7 +770,7 @@ scalar_t vdot_impl(int64_t n, const scalar_t* x, int64_t incx, const scalar_t* y
 #endif
 }
 
-// Skip reinstantiating the explicitly specialized types `float` and `double`.
+// Skip reinstantiating the explicitly specialized types `float`, `double`, `c10::Half`, and `c10::BFloat16`.
 #define INSTANTIATE_DOT_IMPL(scalar_t)  \
   template scalar_t dot_impl<scalar_t>( \
       int64_t n, const scalar_t * x, int64_t incx, const scalar_t * y, int64_t incy);
@@ -730,8 +779,6 @@ INSTANTIATE_DOT_IMPL(int8_t)
 INSTANTIATE_DOT_IMPL(int16_t)
 INSTANTIATE_DOT_IMPL(int)
 INSTANTIATE_DOT_IMPL(int64_t)
-INSTANTIATE_DOT_IMPL(c10::Half)
-INSTANTIATE_DOT_IMPL(c10::BFloat16)
 
 #define INSTANTIATE_VDOT_IMPL(scalar_t)  \
   template scalar_t vdot_impl<scalar_t>( \
diff --git a/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp b/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp
index 8b60eda608e..142a65ac2b3 100644
--- a/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp
@@ -475,12 +475,35 @@ void bf16_gemv_trans(
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && alpha == 1.0 && beta == 0.0);
   return bf16_gemv_trans_fp32_arith_by_dot_products(m, n, a, lda, x, y, incy);
 }
+
+float fp16_dot(
+    const int64_t n,
+    const at::Half* x,
+    const int64_t incx,
+    const at::Half* y,
+    const int64_t incy) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && incy == 1);
+  return fp16_dot_with_fp32_arith(x, y, n);
+}
+
+float bf16_dot(
+    const int64_t n,
+    const at::BFloat16* x,
+    const int64_t incx,
+    const at::BFloat16* y,
+    const int64_t incy) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(incx == 1 && incy == 1);
+  return bf16_dot_with_fp32_arith(x, y, n);
+}
+
 #endif // !defined(C10_MOBILE)
 } // namespace CPU_CAPABILITY
 
 #if !defined(C10_MOBILE)
 REGISTER_DISPATCH(fp16_gemv_trans_stub, &fp16_gemv_trans)
 REGISTER_DISPATCH(bf16_gemv_trans_stub, &bf16_gemv_trans)
+REGISTER_DISPATCH(fp16_dot_stub, &fp16_dot)
+REGISTER_DISPATCH(bf16_dot_stub, &bf16_dot)
 #endif //!defined(C10_MOBILE)
 
 } // namespace at::native
diff --git a/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h b/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h
index cd7a74cd740..70c9fb9eb20 100644
--- a/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h
+++ b/aten/src/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h
@@ -13,6 +13,12 @@ DECLARE_DISPATCH(fp16_gemv_fn, fp16_gemv_trans_stub)
 using bf16_gemv_fn = void(*)(int, int, BFloat16, const BFloat16*, int, const BFloat16*, int, BFloat16, BFloat16*, int);
 DECLARE_DISPATCH(bf16_gemv_fn, bf16_gemv_trans_stub)
 
+using fp16_dot_fn = float(*)(const int64_t, const Half*, const int64_t, const Half*, const int64_t);
+DECLARE_DISPATCH(fp16_dot_fn, fp16_dot_stub)
+
+using bf16_dot_fn = float(*)(const int64_t, const BFloat16*, const int64_t, const BFloat16*, const int64_t);
+DECLARE_DISPATCH(bf16_dot_fn, bf16_dot_stub)
+
 inline namespace CPU_CAPABILITY {
 float fp16_dot_with_fp32_arith(const Half* vec1, const Half* vec2, int64_t len);
 float bf16_dot_with_fp32_arith(const BFloat16* vec1, const BFloat16* vec2, int64_t len);
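
Note on the numerics: `fp16_dot_with_fp32_arith` and `bf16_dot_with_fp32_arith` (already built in ReducedPrecisionFloatGemvFastPathKernel.cpp for the gemv fast paths) widen each reduced-precision element to fp32 and keep the running sum in fp32. Below is a minimal scalar sketch of that contract for reference only; the real kernel vectorizes and unrolls the loop, and the `_ref` name here is invented for illustration.

// Reference sketch, not part of the patch: the scalar semantics that
// fp16_dot_with_fp32_arith vectorizes.
#include <c10/util/Half.h>
#include <cstdint>

float fp16_dot_fp32_arith_ref(const c10::Half* x, const c10::Half* y, int64_t len) {
  float sum = 0.0f;  // fp32 accumulator: rounding happens once at the end,
                     // not once per element through fp16's 10-bit mantissa
  for (int64_t i = 0; i < len; ++i) {
    // Each c10::Half widens to float; the multiply and add run in fp32.
    sum += static_cast<float>(x[i]) * static_cast<float>(y[i]);
  }
  return sum;  // the Half specialization of dot_impl narrows this back to Half
}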
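
And a hypothetical smoke test for the routing added in `dot_impl` (assuming a libtorch build that includes these kernels): contiguous Half inputs should reach `fp16_dot_stub`, while strided inputs keep the `blas_impl::dot_naive` fallback.

// Hypothetical smoke test, not part of the patch.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto a = torch::randn({1024}, torch::kHalf);
  auto b = torch::randn({1024}, torch::kHalf);

  // Contiguous (incx == incy == 1): routed to fp16_dot_stub on non-mobile CPU.
  auto fast = torch::dot(a, b);

  // Stride-2 views (incx == incy == 2): falls back to blas_impl::dot_naive.
  auto naive = torch::dot(a.slice(0, 0, 1024, 2), b.slice(0, 0, 1024, 2));

  std::cout << fast.item<float>() << " " << naive.item<float>() << "\n";
}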