mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Add support for int8 x int8 -> int32 matrix multiplication via cublasGemmEx to stream_executor.
PiperOrigin-RevId: 161137741
This commit is contained in:
parent
755fa7b501
commit
a2ee8bca3f
|
|
@ -97,8 +97,9 @@ enum class ComputationType {
|
||||||
kF16, // 16-bit floating-point
|
kF16, // 16-bit floating-point
|
||||||
kF32, // 32-bit floating-point
|
kF32, // 32-bit floating-point
|
||||||
kF64, // 64-bit floating-point
|
kF64, // 64-bit floating-point
|
||||||
|
kI32, // 32-bit integer
|
||||||
kComplexF32, // Complex number comprised of two f32s.
|
kComplexF32, // Complex number comprised of two f32s.
|
||||||
kComplexF64 // Complex number comprised of two f64s.
|
kComplexF64, // Complex number comprised of two f64s.
|
||||||
};
|
};
|
||||||
|
|
||||||
// Converts a ComputationType to a string.
|
// Converts a ComputationType to a string.
|
||||||
|
|
@ -108,6 +109,15 @@ string ComputationTypeString(ComputationType ty);
|
||||||
// as a hint to the blas library.
|
// as a hint to the blas library.
|
||||||
typedef int64 AlgorithmType;
|
typedef int64 AlgorithmType;
|
||||||
|
|
||||||
|
// blas uses -1 to represent the default algorithm. This happens to match up
|
||||||
|
// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
|
||||||
|
// to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert
|
||||||
|
// to ensure that this assumption does not break.
|
||||||
|
// If another blas implementation uses a different value for the default
|
||||||
|
// algorithm, then it needs to convert kDefaultGemmAlgo to that value
|
||||||
|
// (e.g. via a function called ToWhateverGemmAlgo).
|
||||||
|
constexpr AlgorithmType kDefaultGemmAlgo = -1;
|
||||||
|
|
||||||
// Describes the result of a performance experiment, usually timing the speed of
|
// Describes the result of a performance experiment, usually timing the speed of
|
||||||
// a particular AlgorithmType.
|
// a particular AlgorithmType.
|
||||||
//
|
//
|
||||||
|
|
@ -944,6 +954,12 @@ class BlasSupport {
|
||||||
// output_profile_result->is_valid(). This lets you use this function for
|
// output_profile_result->is_valid(). This lets you use this function for
|
||||||
// choosing the best algorithm among many (some of which may fail) without
|
// choosing the best algorithm among many (some of which may fail) without
|
||||||
// creating a new Stream for each attempt.
|
// creating a new Stream for each attempt.
|
||||||
|
// Overload of DoBlasGemmWithAlgorithm for int8 inputs: `a` and `b` are int8
// device buffers and the result buffer `c` holds int32 values. alpha and beta
// are plain ints passed by value (presumably the usual GEMM scaling factors;
// confirm against the implementation). Pure virtual: every BlasSupport
// implementation must provide it.
virtual bool DoBlasGemmWithAlgorithm(
|
||||||
|
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
|
uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
|
||||||
|
const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int32> *c,
|
||||||
|
int ldc, ComputationType computation_type, AlgorithmType algorithm,
|
||||||
|
ProfileResult *output_profile_result) = 0;
|
||||||
virtual bool DoBlasGemmWithAlgorithm(
|
virtual bool DoBlasGemmWithAlgorithm(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 n, uint64 k, const Eigen::half &alpha,
|
uint64 n, uint64 k, const Eigen::half &alpha,
|
||||||
|
|
@ -1737,6 +1753,13 @@ class BlasSupport {
|
||||||
DeviceMemory<std::complex<double>> *c, int ldc) override; \
|
DeviceMemory<std::complex<double>> *c, int ldc) override; \
|
||||||
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
|
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
|
||||||
override; \
|
override; \
|
||||||
|
bool DoBlasGemmWithAlgorithm( \
|
||||||
|
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
||||||
|
uint64 m, uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, \
|
||||||
|
int lda, const DeviceMemory<int8> &b, int ldb, int beta, \
|
||||||
|
DeviceMemory<int> *c, int ldc, blas::ComputationType computation_type, \
|
||||||
|
blas::AlgorithmType algorithm, \
|
||||||
|
blas::ProfileResult *output_profile_result) override; \
|
||||||
bool DoBlasGemmWithAlgorithm( \
|
bool DoBlasGemmWithAlgorithm( \
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
||||||
uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, \
|
uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, \
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "tensorflow/stream_executor/cuda/cuda_blas.h"
|
#include "tensorflow/stream_executor/cuda/cuda_blas.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
#include <complex>
|
#include <complex>
|
||||||
|
|
||||||
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||||
|
|
@ -483,6 +484,11 @@ struct CUDADataType<std::complex<double>> {
|
||||||
static constexpr cudaDataType_t type = CUDA_C_64F;
|
static constexpr cudaDataType_t type = CUDA_C_64F;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Specialization mapping the host type `int` to the CUDA data type tag
// CUDA_R_32I, so that CUDADataType<int>::type can be used where a
// cudaDataType_t is required.
template <>
|
||||||
|
struct CUDADataType<int> {
|
||||||
|
static constexpr cudaDataType_t type = CUDA_R_32I;
|
||||||
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct CUDADataType<int8> {
|
struct CUDADataType<int8> {
|
||||||
static constexpr cudaDataType_t type = CUDA_R_8I;
|
static constexpr cudaDataType_t type = CUDA_R_8I;
|
||||||
|
|
@ -511,6 +517,8 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
|
||||||
return CUDA_R_32F;
|
return CUDA_R_32F;
|
||||||
case blas::ComputationType::kF64:
|
case blas::ComputationType::kF64:
|
||||||
return CUDA_R_64F;
|
return CUDA_R_64F;
|
||||||
|
case blas::ComputationType::kI32:
|
||||||
|
return CUDA_R_32I;
|
||||||
case blas::ComputationType::kComplexF32:
|
case blas::ComputationType::kComplexF32:
|
||||||
return CUDA_C_32F;
|
return CUDA_C_32F;
|
||||||
case blas::ComputationType::kComplexF64:
|
case blas::ComputationType::kComplexF64:
|
||||||
|
|
@ -1849,12 +1857,12 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
||||||
CUDAComplex(CUDAMemoryMutable(c)), ldc);
|
CUDAComplex(CUDAMemoryMutable(c)), ldc);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename InT, typename OutT, typename CompT>
|
||||||
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
|
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 n, uint64 k, const T &alpha, const DeviceMemory<T> &a, int lda,
|
uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
|
||||||
const DeviceMemory<T> &b, int ldb, const T &beta, DeviceMemory<T> *c,
|
const DeviceMemory<InT> &b, int ldb, const CompT &beta,
|
||||||
int ldc, blas::ComputationType computation_type,
|
DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
|
||||||
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
||||||
// CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx.
|
// CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx.
|
||||||
#if CUDA_VERSION < 8000
|
#if CUDA_VERSION < 8000
|
||||||
|
|
@ -1881,12 +1889,15 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cudaDataType_t data_type = CUDADataType<T>::type;
|
cudaDataType_t cuda_in_type = CUDADataType<InT>::type;
|
||||||
|
// Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
|
||||||
|
// we do the following compile-time check on the default value:
|
||||||
|
static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
|
||||||
bool result = DoBlasInternalFailureOK(
|
bool result = DoBlasInternalFailureOK(
|
||||||
wrap::cublasGemmEx, stream, /* pointer_mode_host = */ true,
|
wrap::cublasGemmEx, stream, /* pointer_mode_host = */ true,
|
||||||
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
|
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
|
||||||
CUDAMemory(a), data_type, lda, CUDAMemory(b), data_type, ldb, &beta,
|
CUDAMemory(a), cuda_in_type, lda, CUDAMemory(b), cuda_in_type, ldb, &beta,
|
||||||
CUDAMemoryMutable(c), data_type, ldc,
|
CUDAMemoryMutable(c), CUDADataType<OutT>::type, ldc,
|
||||||
CUDAComputationType(computation_type),
|
CUDAComputationType(computation_type),
|
||||||
static_cast<cublasGemmAlgo_t>(algorithm));
|
static_cast<cublasGemmAlgo_t>(algorithm));
|
||||||
|
|
||||||
|
|
@ -1920,6 +1931,17 @@ bool CUDABlas::GetBlasGemmAlgorithms(
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// int8 x int8 -> int overload. Forwards every argument unchanged to
// DoBlasGemmWithAlgorithmImpl and returns its result; the template arguments
// of the Impl are left to deduction from `a`/`b` (int8), `c` (int) and
// alpha/beta (int).
bool CUDABlas::DoBlasGemmWithAlgorithm(
|
||||||
|
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
|
uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
|
||||||
|
const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
|
||||||
|
int ldc, blas::ComputationType computation_type,
|
||||||
|
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
||||||
|
return DoBlasGemmWithAlgorithmImpl(
|
||||||
|
stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||||
|
computation_type, algorithm, output_profile_result);
|
||||||
|
}
|
||||||
|
|
||||||
bool CUDABlas::DoBlasGemmWithAlgorithm(
|
bool CUDABlas::DoBlasGemmWithAlgorithm(
|
||||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 n, uint64 k, const Eigen::half &alpha,
|
uint64 n, uint64 k, const Eigen::half &alpha,
|
||||||
|
|
|
||||||
|
|
@ -118,16 +118,14 @@ class CUDABlas : public blas::BlasSupport {
|
||||||
// and we want to avoid pulling in a dependency on Eigen. When we pass the
|
// and we want to avoid pulling in a dependency on Eigen. When we pass the
|
||||||
// references to cublas, we essentially reinterpret_cast to __half, which is
|
// references to cublas, we essentially reinterpret_cast to __half, which is
|
||||||
// safe because Eigen::half inherits from __half.
|
// safe because Eigen::half inherits from __half.
|
||||||
template <typename T>
|
template <typename InT, typename OutT, typename CompT>
|
||||||
bool DoBlasGemmWithAlgorithmImpl(Stream *stream, blas::Transpose transa,
|
bool DoBlasGemmWithAlgorithmImpl(
|
||||||
blas::Transpose transb, uint64 m, uint64 n,
|
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
uint64 k, const T &alpha,
|
uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
|
||||||
const DeviceMemory<T> &a, int lda,
|
int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
|
||||||
const DeviceMemory<T> &b, int ldb,
|
DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
|
||||||
const T &beta, DeviceMemory<T> *c, int ldc,
|
blas::AlgorithmType algorithm,
|
||||||
blas::ComputationType computation_type,
|
blas::ProfileResult *output_profile_result);
|
||||||
blas::AlgorithmType algorithm,
|
|
||||||
blas::ProfileResult *output_profile_result);
|
|
||||||
|
|
||||||
// mutex that guards the cuBLAS handle for this device.
|
// mutex that guards the cuBLAS handle for this device.
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
|
|
|
||||||
|
|
@ -3482,6 +3482,27 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
|
||||||
algorithm, output_profile_result);
|
algorithm, output_profile_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// int8 x int8 -> int overload of ThenBlasGemmWithAlgorithm.
// Logs the call arguments through VLOG_CALL (output_profile_result is not
// among the logged parameters), then dispatches to
// blas::BlasSupport::DoBlasGemmWithAlgorithm through a ThenBlasWithProfileImpl
// functor and returns whatever that functor returns.
Stream &Stream::ThenBlasGemmWithAlgorithm(
|
||||||
|
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
||||||
|
uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
|
||||||
|
const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
|
||||||
|
int ldc, blas::ComputationType computation_type,
|
||||||
|
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
|
||||||
|
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
||||||
|
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
||||||
|
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
|
||||||
|
PARAM(algorithm));
|
||||||
|
|
||||||
|
// The template argument list mirrors this overload's parameter types, from
// transa through algorithm (output_profile_result is not listed).
ThenBlasWithProfileImpl<
|
||||||
|
blas::Transpose, blas::Transpose, uint64, uint64, uint64, int,
|
||||||
|
const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int,
|
||||||
|
DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
|
||||||
|
impl;
|
||||||
|
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
|
||||||
|
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
|
||||||
|
algorithm, output_profile_result);
|
||||||
|
}
|
||||||
|
|
||||||
Stream &Stream::ThenBlasGemmWithAlgorithm(
|
Stream &Stream::ThenBlasGemmWithAlgorithm(
|
||||||
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
|
||||||
uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
|
uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
|
||||||
|
|
|
||||||
|
|
@ -1257,6 +1257,15 @@ class Stream {
|
||||||
const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
|
const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
|
||||||
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
|
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
|
||||||
blas::ProfileResult *output_profile_result);
|
blas::ProfileResult *output_profile_result);
|
||||||
|
// Overload for int8 inputs `a` and `b` with an int output buffer `c`;
// alpha and beta are ints passed by value. Returns a Stream reference
// (presumably *this, to allow call chaining; confirm in the definition).
Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
|
||||||
|
blas::Transpose transb, uint64 m, uint64 n,
|
||||||
|
uint64 k, int alpha,
|
||||||
|
const DeviceMemory<int8> &a, int lda,
|
||||||
|
const DeviceMemory<int8> &b, int ldb,
|
||||||
|
int beta, DeviceMemory<int> *c, int ldc,
|
||||||
|
blas::ComputationType computation_type,
|
||||||
|
blas::AlgorithmType algorithm,
|
||||||
|
blas::ProfileResult *output_profile_result);
|
||||||
Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
|
Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
|
||||||
blas::Transpose transb, uint64 m, uint64 n,
|
blas::Transpose transb, uint64 m, uint64 n,
|
||||||
uint64 k, float alpha,
|
uint64 k, float alpha,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user