Add support for int8 x int8 -> int32 matrix multiplication via cublasGemmEx to stream_executor.

PiperOrigin-RevId: 161137741
Authored by A. Unique TensorFlower on 2017-07-06 15:15:27 -07:00; committed by TensorFlower Gardener
parent 755fa7b501
commit a2ee8bca3f
5 changed files with 91 additions and 18 deletions

--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h

@@ -97,8 +97,9 @@ enum class ComputationType {
   kF16,         // 16-bit floating-point
   kF32,         // 32-bit floating-point
   kF64,         // 64-bit floating-point
+  kI32,         // 32-bit integer
   kComplexF32,  // Complex number comprised of two f32s.
-  kComplexF64   // Complex number comprised of two f64s.
+  kComplexF64,  // Complex number comprised of two f64s.
 };
 
 // Converts a ComputationType to a string.
@@ -108,6 +109,15 @@ string ComputationTypeString(ComputationType ty);
 // as a hint to the blas library.
 typedef int64 AlgorithmType;
 
+// blas uses -1 to represent the default algorithm. This happens to match up
+// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
+// to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert
+// to ensure that this assumption does not break.
+// If another blas implementation uses a different value for the default
+// algorithm, then it needs to convert kDefaultGemmAlgo to that value
+// (e.g. via a function called ToWhateverGemmAlgo).
+constexpr AlgorithmType kDefaultGemmAlgo = -1;
+
 // Describes the result of a performance experiment, usually timing the speed of
 // a particular AlgorithmType.
 //
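Purely as an illustration of the caveat in that comment (no such helper exists in this commit; the names are taken from the comment's own hypothetical example): a backend whose default-algorithm constant is not -1 might translate kDefaultGemmAlgo along these lines.

// Hypothetical backend algorithm enum; only the conversion pattern matters.
enum WhateverGemmAlgo_t { WHATEVER_GEMM_DEFAULT = 0, WHATEVER_GEMM_ALGO0 = 1 };

// Assumes blas.h's AlgorithmType and kDefaultGemmAlgo (declared above) are
// in scope.
WhateverGemmAlgo_t ToWhateverGemmAlgo(AlgorithmType algorithm) {
  // Map the blas-level default (-1) onto this backend's own default value;
  // the remaining values are assumed to line up with the backend's enum.
  if (algorithm == kDefaultGemmAlgo) return WHATEVER_GEMM_DEFAULT;
  return static_cast<WhateverGemmAlgo_t>(algorithm);
}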
@@ -944,6 +954,12 @@ class BlasSupport {
   // output_profile_result->is_valid(). This lets you use this function for
   // choosing the best algorithm among many (some of which may fail) without
   // creating a new Stream for each attempt.
+  virtual bool DoBlasGemmWithAlgorithm(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
+      const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int32> *c,
+      int ldc, ComputationType computation_type, AlgorithmType algorithm,
+      ProfileResult *output_profile_result) = 0;
   virtual bool DoBlasGemmWithAlgorithm(
       Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
       uint64 n, uint64 k, const Eigen::half &alpha,
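Because the new overload is pure virtual, every BlasSupport subclass has to define it; a backend without int8 GEMM support would presumably just fail the call so probing callers can fall back. A hypothetical sketch (HostBlas is an invented class name, not part of this commit):

bool HostBlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
    const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int32> *c,
    int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
  // No int8 GEMM on this backend; returning false signals failure, which the
  // header comment above says algorithm-probing callers must tolerate.
  return false;
}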
@@ -1737,6 +1753,13 @@ class BlasSupport {
       DeviceMemory<std::complex<double>> *c, int ldc) override;                \
   bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
       override;                                                                \
+  bool DoBlasGemmWithAlgorithm(                                                \
+      Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
+      uint64 m, uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a,    \
+      int lda, const DeviceMemory<int8> &b, int ldb, int beta,                 \
+      DeviceMemory<int> *c, int ldc, blas::ComputationType computation_type,   \
+      blas::AlgorithmType algorithm,                                           \
+      blas::ProfileResult *output_profile_result) override;                    \
   bool DoBlasGemmWithAlgorithm(                                                \
       Stream *stream, blas::Transpose transa, blas::Transpose transb,          \
       uint64 m, uint64 n, uint64 k, const Eigen::half &alpha,                  \

--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc

@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_blas.h"
 
+#include <assert.h>
 #include <complex>
 
 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
@@ -483,6 +484,11 @@ struct CUDADataType<std::complex<double>> {
   static constexpr cudaDataType_t type = CUDA_C_64F;
 };
 
+template <>
+struct CUDADataType<int> {
+  static constexpr cudaDataType_t type = CUDA_R_32I;
+};
+
 template <>
 struct CUDADataType<int8> {
   static constexpr cudaDataType_t type = CUDA_R_8I;
@@ -511,6 +517,8 @@ cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
       return CUDA_R_32F;
     case blas::ComputationType::kF64:
       return CUDA_R_64F;
+    case blas::ComputationType::kI32:
+      return CUDA_R_32I;
     case blas::ComputationType::kComplexF32:
       return CUDA_C_32F;
     case blas::ComputationType::kComplexF64:
@@ -1849,12 +1857,12 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          CUDAComplex(CUDAMemoryMutable(c)), ldc);
 }
 
-template <typename T>
+template <typename InT, typename OutT, typename CompT>
 bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
-    uint64 n, uint64 k, const T &alpha, const DeviceMemory<T> &a, int lda,
-    const DeviceMemory<T> &b, int ldb, const T &beta, DeviceMemory<T> *c,
-    int ldc, blas::ComputationType computation_type,
+    uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
+    const DeviceMemory<InT> &b, int ldb, const CompT &beta,
+    DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
   // CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx.
 #if CUDA_VERSION < 8000
@@ -1881,12 +1889,15 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
     }
   }
 
-  cudaDataType_t data_type = CUDADataType<T>::type;
+  cudaDataType_t cuda_in_type = CUDADataType<InT>::type;
+  // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
+  // we do the following compile-time check on the default value:
+  static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
   bool result = DoBlasInternalFailureOK(
       wrap::cublasGemmEx, stream, /* pointer_mode_host = */ true,
       CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
-      CUDAMemory(a), data_type, lda, CUDAMemory(b), data_type, ldb, &beta,
-      CUDAMemoryMutable(c), data_type, ldc,
+      CUDAMemory(a), cuda_in_type, lda, CUDAMemory(b), cuda_in_type, ldb, &beta,
+      CUDAMemoryMutable(c), CUDADataType<OutT>::type, ldc,
       CUDAComputationType(computation_type),
       static_cast<cublasGemmAlgo_t>(algorithm));
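For readers unfamiliar with cublasGemmEx, here is a standalone sketch of the same int8 x int8 -> int32 call made directly against the CUDA 8-era cuBLAS API, without the stream_executor wrappers. It is an illustration, not code from this commit; error checking is elided and the buffer names are invented. cuBLAS places alignment restrictions on int8 GEMM, so m and k are kept multiples of four.

// Minimal sketch, assuming a CUDA 8-era toolkit; compile as C++ with nvcc
// and link against cublas.
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <vector>

int main() {
  const int m = 4, n = 4, k = 4;  // small, alignment-friendly sizes
  std::vector<int8_t> host_a(m * k, 1), host_b(k * n, 2);

  int8_t *a, *b;
  int32_t *c;
  cudaMalloc(&a, m * k * sizeof(int8_t));
  cudaMalloc(&b, k * n * sizeof(int8_t));
  cudaMalloc(&c, m * n * sizeof(int32_t));
  cudaMemcpy(a, host_a.data(), m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
  cudaMemcpy(b, host_b.data(), k * n * sizeof(int8_t), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  // With a CUDA_R_32I computation type, alpha and beta are host int32 scalars.
  const int32_t alpha = 1, beta = 0;
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
               a, CUDA_R_8I, /*lda=*/m, b, CUDA_R_8I, /*ldb=*/k, &beta,
               c, CUDA_R_32I, /*ldc=*/m,
               CUDA_R_32I,           // what blas::ComputationType::kI32 maps to
               CUBLAS_GEMM_DFALT);   // the default algorithm, i.e. -1

  std::vector<int32_t> host_c(m * n);
  cudaMemcpy(host_c.data(), c, m * n * sizeof(int32_t), cudaMemcpyDeviceToHost);
  // Each entry of C should now be 1 * 2 * k = 8.

  cublasDestroy(handle);
  cudaFree(a);
  cudaFree(b);
  cudaFree(c);
  return 0;
}

Note that because the computation type is 32-bit integer, alpha and beta are int32 rather than int8, which is why the new overloads in this commit take int alpha and int beta alongside DeviceMemory<int8> operands.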
@@ -1920,6 +1931,17 @@ bool CUDABlas::GetBlasGemmAlgorithms(
   return true;
 }
 
+bool CUDABlas::DoBlasGemmWithAlgorithm(
+    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+    uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
+    const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
+    int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  return DoBlasGemmWithAlgorithmImpl(
+      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+      computation_type, algorithm, output_profile_result);
+}
+
 bool CUDABlas::DoBlasGemmWithAlgorithm(
     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
     uint64 n, uint64 k, const Eigen::half &alpha,

--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h

@@ -118,16 +118,14 @@ class CUDABlas : public blas::BlasSupport {
   // and we want to avoid pulling in a dependency on Eigen. When we pass the
   // references to cublas, we essentially reinterpret_cast to __half, which is
   // safe because Eigen::half inherits from __half.
-  template <typename T>
-  bool DoBlasGemmWithAlgorithmImpl(Stream *stream, blas::Transpose transa,
-                                   blas::Transpose transb, uint64 m, uint64 n,
-                                   uint64 k, const T &alpha,
-                                   const DeviceMemory<T> &a, int lda,
-                                   const DeviceMemory<T> &b, int ldb,
-                                   const T &beta, DeviceMemory<T> *c, int ldc,
-                                   blas::ComputationType computation_type,
-                                   blas::AlgorithmType algorithm,
-                                   blas::ProfileResult *output_profile_result);
+  template <typename InT, typename OutT, typename CompT>
+  bool DoBlasGemmWithAlgorithmImpl(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
+      int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
+      DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
+      blas::AlgorithmType algorithm,
+      blas::ProfileResult *output_profile_result);
 
   // mutex that guards the cuBLAS handle for this device.
   mutex mu_;

--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc

@@ -3482,6 +3482,27 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
               algorithm, output_profile_result);
 }
 
+Stream &Stream::ThenBlasGemmWithAlgorithm(
+    blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+    uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
+    const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
+    int ldc, blas::ComputationType computation_type,
+    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+  VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
+            PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
+            PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
+            PARAM(algorithm));
+
+  ThenBlasWithProfileImpl<
+      blas::Transpose, blas::Transpose, uint64, uint64, uint64, int,
+      const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int,
+      DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
+      impl;
+  return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
+              m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+              algorithm, output_profile_result);
+}
+
 Stream &Stream::ThenBlasGemmWithAlgorithm(
     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
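To see the pieces from the caller's side, here is a hedged usage sketch of the new overload. The helper below is hypothetical: it assumes a live Stream, device buffers the caller has already allocated and filled, stream_executor's int8/uint64 typedefs in scope, and the perftools::gputools namespace the library used at the time of this commit.

#include "tensorflow/stream_executor/stream.h"

namespace gpu = perftools::gputools;

// Hypothetical helper, not part of this commit. Computes c = a * b with int8
// inputs accumulated in int32, using the default cuBLAS algorithm.
bool RunInt8Gemm(gpu::Stream *stream, uint64 m, uint64 n, uint64 k,
                 const gpu::DeviceMemory<int8> &a,
                 const gpu::DeviceMemory<int8> &b,
                 gpu::DeviceMemory<int> *c) {
  stream->ThenBlasGemmWithAlgorithm(
      gpu::blas::Transpose::kNoTranspose, gpu::blas::Transpose::kNoTranspose,
      m, n, k, /*alpha=*/1, a, /*lda=*/static_cast<int>(m), b,
      /*ldb=*/static_cast<int>(k), /*beta=*/0, c, /*ldc=*/static_cast<int>(m),
      gpu::blas::ComputationType::kI32, gpu::blas::kDefaultGemmAlgo,
      /*output_profile_result=*/nullptr);  // null: no timing, as elsewhere
  // The caller still needs to synchronize the stream before reading c back.
  return stream->ok();
}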

--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h

@@ -1257,6 +1257,15 @@ class Stream {
       const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
       blas::ProfileResult *output_profile_result);
+  Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
+                                    blas::Transpose transb, uint64 m, uint64 n,
+                                    uint64 k, int alpha,
+                                    const DeviceMemory<int8> &a, int lda,
+                                    const DeviceMemory<int8> &b, int ldb,
+                                    int beta, DeviceMemory<int> *c, int ldc,
+                                    blas::ComputationType computation_type,
+                                    blas::AlgorithmType algorithm,
+                                    blas::ProfileResult *output_profile_result);
   Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
                                     blas::Transpose transb, uint64 m, uint64 n,
                                     uint64 k, float alpha,