Automated g4 rollback of changelist 162423171

PiperOrigin-RevId: 162437318
This commit is contained in:
A. Unique TensorFlower 2017-07-18 19:36:18 -07:00 committed by TensorFlower Gardener
parent 9293c557bd
commit 491beb74cc
18 changed files with 196 additions and 1460 deletions

View File

@ -405,7 +405,6 @@ tf_cuda_library(
"util/tensor_slice_reader_cache.h",
"util/tensor_slice_writer.h",
"util/use_cudnn.h",
"util/matmul_autotune.h",
"util/util.h",
"util/work_sharder.h",
] + select({

View File

@ -157,7 +157,6 @@ cc_library(
hdrs = ["conv_2d.h"],
deps = [
":eigen_helpers",
":gpu_util_hdrs",
"//tensorflow/core:framework",
"//third_party/eigen3",
],
@ -249,15 +248,6 @@ cc_library(
],
)
cc_library(
name = "gpu_util_hdrs",
hdrs = ["gpu_utils.h"],
deps = [
":eigen_helpers",
"//third_party/eigen3",
],
)
tf_cc_test(
name = "ops_util_test",
size = "small",
@ -2416,9 +2406,7 @@ tf_kernel_library(
],
"//conditions:default": [],
}),
deps = MATH_DEPS + [
":gpu_util_hdrs",
] + select({
deps = MATH_DEPS + select({
":xsmm": [
"@libxsmm_archive//:xsmm_avx",
],

View File

@ -21,12 +21,27 @@ limitations under the License.
#include <tuple>
#include <unordered_map>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
namespace tensorflow {
// TODO(zhengxq): move this to gpu_util.h. The use of such wrappers is wide
// spread.
template <typename T>
inline perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
uint64 size) {
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
size * sizeof(T));
perftools::gputools::DeviceMemory<T> typed(wrapped);
return typed;
}
// Get the Cudnn workspace limit from the environment variable, which is in MB.
// Return the workspace memory limit in bytes. If no value is set, return the
// default value.
@ -41,10 +56,12 @@ class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
virtual ~CudnnScratchAllocator() {}
CudnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
virtual int64 GetMemoryLimitInBytes(
perftools::gputools::Stream* stream) override {
return memory_limit_;
}
perftools::gputools::port::StatusOr<perftools::gputools::DeviceMemory<uint8>>
virtual perftools::gputools::port::StatusOr<
perftools::gputools::DeviceMemory<uint8>>
AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override {
Tensor temporary_memory;
if (byte_size > memory_limit_) {
@ -168,6 +185,112 @@ class ConvParameters {
typedef Eigen::GpuDevice GPUDevice;
// A helper class that looks up the best autotuned config from parameters.
// Due to the noisy nature of autotune, especially with multiple devices, it
// only accepts a config if its margin exceeds a threshold.
// For the same shape configs, if a new best config matches the previous best,
// they get promoted; otherwise, the winner gets demoted. This process stops
// when the winner's score exceeds the threshold.
// In a bad case when two configs are very close to each other and flips
// back and forth randomly, the expected number of experiments before autotune
// settles is O(threshold ^ 2). So we recommend that number of warmup runs
// for any benchmarks.
template <typename Parameters, typename Config>
class AutoTuneMap {
public:
bool Find(const Parameters& params, Config* config) const {
mutex_lock lock(mu_);
auto iter = params_config_map_.find(params);
if (iter == params_config_map_.end() ||
iter->second.score < min_score_threshold_) {
return false;
}
*config = iter->second.config;
return true;
}
void Insert(const ConvParameters& params, const Config& config) {
mutex_lock lock(mu_);
auto iter = params_config_map_.find(params);
int new_score = 0;
if (iter == params_config_map_.end()) {
// Create a new entry if params is new.
VLOG(1) << GetActionSummary("creates", params, config);
params_config_map_.insert(std::make_pair(params, ValueType{config, 1}));
new_score = 1;
} else if (iter->second.score < min_score_threshold_) {
DCHECK(iter->second.score > 0);
if (iter->second.config != config) {
// If it is different from the current winner, demotes the winner.
VLOG(1) << GetActionSummary("demotes", params, config);
new_score = --iter->second.score;
if (new_score <= 0) {
VLOG(1) << GetActionSummary("erases", params, config);
params_config_map_.erase(iter);
}
} else {
// If it is the same as the current winner, promotes the winner.
VLOG(1) << GetActionSummary("promotes", params, config);
new_score = ++iter->second.score;
}
}
if (new_score >= min_score_threshold_) {
VLOG(1) << GetActionSummary("accepts", params, config);
}
}
private:
AutoTuneMap(const string& name) : name_(name) {
min_score_threshold_ = 1;
const char* threshold_str = getenv("TF_AUTOTUNE_THRESHOLD");
if (threshold_str != nullptr) {
strings::safe_strto32(threshold_str, &min_score_threshold_);
}
min_score_threshold_ = std::max(min_score_threshold_, 1);
}
template <class Group, class Params, class Cfg>
friend class AutoTuneSingleton;
struct Hasher {
std::size_t operator()(const Parameters& parameter) const {
return parameter.hash();
}
};
string GetActionSummary(StringPiece action, const Parameters& params,
const Config& config) {
return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(),
action.ToString().c_str(), params.ToString().c_str(),
config.ToString().c_str());
}
mutable mutex mu_;
struct ValueType {
Config config;
int32 score;
};
std::unordered_map<Parameters, ValueType, Hasher> params_config_map_
GUARDED_BY(mu_);
string name_;
int32 min_score_threshold_;
TF_DISALLOW_COPY_AND_ASSIGN(AutoTuneMap);
};
// A Singleton helper that manages the global autotune results by groups.
// The caller specified arbitrary Group type that can distinguish between
// different autotune results, even if their Parameters and Configs are the
// same.
template <class Group, typename Parameters, typename Config>
class AutoTuneSingleton {
public:
typedef AutoTuneMap<Parameters, Config> AutoTuneType;
static AutoTuneType* GetInstance() {
static AutoTuneType* instance = new AutoTuneType(Group::name());
return instance;
}
};
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -1,166 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_
#if GOOGLE_CUDA
#include <unordered_map>
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
namespace tensorflow {
template <typename T>
inline perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
uint64 size) {
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
size * sizeof(T));
perftools::gputools::DeviceMemory<T> typed(wrapped);
return typed;
}
// A helper class that looks up the best autotuned config from parameters.
// Due to the noisy nature of autotune, especially with multiple devices, it
// only accepts a config if its margin exceeds a threshold.
// For the same shape configs, if a new best config matches the previous best,
// they get promoted; otherwise, the winner gets demoted. This process stops
// when the winner's score exceeds the threshold.
// In a bad case when two configs are very close to each other and flips
// back and forth randomly, the expected number of experiments before autotune
// settles is O(threshold ^ 2). So we recommend that number of warmup runs
// for any benchmarks.
template <typename Parameters, typename Config>
class AutoTuneMap {
public:
bool Find(const Parameters& params, Config* config) const {
mutex_lock lock(mu_);
auto iter = params_config_map_.find(params);
if (iter == params_config_map_.end() ||
(iter->second.score < min_score_threshold_ &&
iter->second.count <= max_autotune_count_)) {
return false;
}
*config = iter->second.config;
return true;
}
void Insert(const Parameters& params, const Config& config) {
mutex_lock lock(mu_);
auto iter = params_config_map_.find(params);
int new_score = 0;
if (iter == params_config_map_.end()) {
// Create a new entry if params is new.
VLOG(1) << GetActionSummary("creates", params, config);
params_config_map_.insert(
std::make_pair(params, ValueType{config, 1, 1}));
new_score = 1;
} else if (iter->second.score < min_score_threshold_ &&
iter->second.count <= max_autotune_count_) {
DCHECK_GT(iter->second.score, 0);
if (iter->second.config != config) {
// If it is different from the current winner, demotes the winner.
VLOG(1) << GetActionSummary("demotes", params, config);
new_score = --iter->second.score;
++iter->second.count;
if (new_score <= 0) {
VLOG(1) << GetActionSummary("erases", params, config);
params_config_map_.erase(iter);
}
} else {
// If it is the same as the current winner, promotes the winner.
VLOG(1) << GetActionSummary("promotes", params, config);
new_score = ++iter->second.score;
++iter->second.count;
}
}
if (new_score >= min_score_threshold_) {
VLOG(1) << GetActionSummary("accepts", params, config);
}
}
private:
AutoTuneMap(const string& name) : name_(name) {
min_score_threshold_ = 1;
int min_warmup_iterations = 10;
const char* threshold_str = getenv("TF_AUTOTUNE_THRESHOLD");
if (threshold_str != nullptr) {
strings::safe_strto32(threshold_str, &min_score_threshold_);
}
const char* min_warmup_iteration_str =
getenv("TF_AUTOTUNE_MIN_WARMUP_ITERATIONS");
if (min_warmup_iteration_str != nullptr) {
strings::safe_strto32(min_warmup_iteration_str, &min_warmup_iterations);
}
min_score_threshold_ = std::max(min_score_threshold_, 1);
max_autotune_count_ = std::max(
5 * min_score_threshold_ * min_score_threshold_, min_warmup_iterations);
}
template <class Group, class Params, class Cfg>
friend class AutoTuneSingleton;
struct Hasher {
std::size_t operator()(const Parameters& parameter) const {
return parameter.hash();
}
};
string GetActionSummary(StringPiece action, const Parameters& params,
const Config& config) {
return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(),
action.ToString().c_str(), params.ToString().c_str(),
config.ToString().c_str());
}
mutable mutex mu_;
struct ValueType {
Config config;
int32 score;
int32 count;
};
std::unordered_map<Parameters, ValueType, Hasher> params_config_map_
GUARDED_BY(mu_);
string name_;
int32 min_score_threshold_;
int32 max_autotune_count_;
TF_DISALLOW_COPY_AND_ASSIGN(AutoTuneMap);
};
// A Singleton helper that manages the global autotune results by groups.
// The caller specified arbitrary Group type that can distinguish between
// different autotune results, even if their Parameters and Configs are the
// same.
template <class Group, typename Parameters, typename Config>
class AutoTuneSingleton {
public:
typedef AutoTuneMap<Parameters, Config> AutoTuneType;
static AutoTuneType* GetInstance() {
static AutoTuneType* instance = new AutoTuneType(Group::name());
return instance;
}
};
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_

View File

@ -23,15 +23,27 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "cuda/include/cuda.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
#if GOOGLE_CUDA
namespace {
template <typename T>
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
perftools::gputools::DeviceMemory<T> typed(wrapped);
return typed;
}
} // namespace
#endif // GOOGLE_CUDA
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
@ -111,16 +123,10 @@ bool ExplicitVectorMatrixOptimization<Eigen::half>(
template <typename Device, typename T>
struct LaunchMatMulBase {
#if GOOGLE_CUDA
typedef perftools::gputools::blas::AlgorithmType AlgorithmType;
#else
typedef int64 AlgorithmType;
#endif // GOOGLE_CUDA
static void launch(
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
std::vector<AlgorithmType>* algorithms, bool use_aututone, Tensor* out) {
Tensor* out) {
#ifndef TENSORFLOW_USE_SYCL
// An explicit vector-matrix multiply is much better optimized than an
// implicit one and this is a bottleneck during non-batched inference.
@ -134,10 +140,6 @@ struct LaunchMatMulBase {
}
#endif // TENSORFLOW_USE_SYCL
}
static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
std::vector<int64>* algorithms,
bool* algorithm_set_flag) {}
};
// On CPUs, we ignore USE_CUBLAS
template <typename T>
@ -157,39 +159,24 @@ struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
#if GOOGLE_CUDA
namespace {
template <typename T>
struct LaunchBlasGemv {
static void Compute(
OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
uint64 m, uint64 n, const perftools::gputools::DeviceMemory<T>& a,
const perftools::gputools::DeviceMemory<T>& b,
perftools::gputools::DeviceMemory<T>* c,
perftools::gputools::blas::ProfileResult* output_profile) {
static void Compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
bool trans, uint64 m, uint64 n,
const perftools::gputools::DeviceMemory<T>& a,
const perftools::gputools::DeviceMemory<T>& b,
perftools::gputools::DeviceMemory<T>* c) {
const auto blas_trans =
trans ? perftools::gputools::blas::Transpose::kTranspose
: perftools::gputools::blas::Transpose::kNoTranspose;
if (output_profile == nullptr) {
bool blas_launch_status =
stream
->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
static_cast<T>(0.0), c, 1)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(
errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
}
} else {
bool blas_launch_status =
stream
->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
a, m, b, 1, static_cast<T>(0.0), c, 1,
output_profile)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(errors::Internal(
"Blas GEMV with profiling launch failed: m=", m, ", n=", n));
}
bool blas_launch_status =
stream
->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
static_cast<T>(0.0), c, 1)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(
errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
}
}
@ -201,8 +188,7 @@ void LaunchBlasGemv<Eigen::half>::Compute(
OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
uint64 m, uint64 n, const perftools::gputools::DeviceMemory<Eigen::half>& a,
const perftools::gputools::DeviceMemory<Eigen::half>& b,
perftools::gputools::DeviceMemory<Eigen::half>* c,
perftools::gputools::blas::ProfileResult* output_profile) {
perftools::gputools::DeviceMemory<Eigen::half>* c) {
ctx->SetStatus(errors::Internal(
"Blas GEMV launch failed: GEMV is not implemented for float16."));
}
@ -214,55 +200,15 @@ bool LaunchBlasGemv<Eigen::half>::IsSupported() {
} // namespace
bool GetCublasAutotuneComputationType(
const DataType& dtype,
perftools::gputools::blas::ComputationType* compute_type) {
using perftools::gputools::blas::ComputationType;
bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
switch (dtype) {
case DT_HALF:
case DT_BFLOAT16:
if (use_f32_for_f16_computation) {
*compute_type = ComputationType::kF32;
} else {
*compute_type = ComputationType::kF16;
}
return false;
case DT_FLOAT:
*compute_type = ComputationType::kF32;
return true;
case DT_DOUBLE:
*compute_type = ComputationType::kF64;
return true;
default:
// Unsupported compute_type, return false.
return false;
}
}
// A dummy type to group matmul autotune results together.
struct MatmulAutoTuneGroup {
static string name() { return "Matmul"; }
};
typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
perftools::gputools::blas::AlgorithmConfig>
AutoTuneMatmul;
template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
static void launch(
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
using perftools::gputools::blas::AlgorithmConfig;
using perftools::gputools::blas::ComputationType;
using perftools::gputools::blas::ProfileResult;
using perftools::gputools::blas::Transpose;
using perftools::gputools::blas::kDefaultAlgorithm;
using perftools::gputools::blas::kDefaultBlasGemm;
using perftools::gputools::blas::kDefaultBlasGemv;
using perftools::gputools::blas::kNoAlgorithm;
Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
Tensor* out) {
perftools::gputools::blas::Transpose trans[] = {
perftools::gputools::blas::Transpose::kNoTranspose,
perftools::gputools::blas::Transpose::kTranspose};
const uint64 m = a.dim_size(1 - dim_pair[0].first);
const uint64 k = a.dim_size(dim_pair[0].first);
const uint64 n = b.dim_size(1 - dim_pair[0].second);
@ -274,155 +220,34 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
auto* stream = ctx->op_device_context()->stream();
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
a.template flat<T>().size());
auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
b.template flat<T>().size());
auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
out->template flat<T>().size());
auto alpha = static_cast<T>(1.0);
auto beta = static_cast<T>(0.0);
int device_id = stream->parent()->device_ordinal();
DataType dtype = a.dtype();
MatmulParameters matmul_parameters = {
transpose_a, transpose_b, m, n, k, dtype, device_id,
};
AlgorithmConfig algorithm_config(kNoAlgorithm);
ComputationType computation_type;
bool compute_type_supported =
GetCublasAutotuneComputationType(dtype, &computation_type);
if (use_autotune && compute_type_supported && !algorithms->empty()) {
ProfileResult best_result;
// TODO(yangzihao): Unify this code with conv autotuning.
if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
&algorithm_config)) {
ProfileResult profile_result;
for (auto profile_algorithm : (*algorithms)) {
// Cublas does
// C = A x B
// where A, B and C are assumed to be in column major.
// We want the output to be in row-major, so we can compute
// C' = B' x A' (' stands for transpose)
bool cublas_launch_status =
stream
->ThenBlasGemmWithAlgorithm(
blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
&c_ptr, n, computation_type, profile_algorithm,
&profile_result)
.ok();
if (cublas_launch_status) {
if (profile_result.is_valid()) {
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
}
}
}
// Try BlasGemmWithProfiling
bool cublas_launch_status =
stream
->ThenBlasGemmWithProfiling(
blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
&c_ptr, n, &profile_result)
.ok();
if (cublas_launch_status) {
if (profile_result.is_valid()) {
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
}
}
// Try BlasGemvWithProfiling
if (LaunchBlasGemv<T>::IsSupported() && n == 1) {
LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
transpose_a ? m : k, transpose_a ? k : m,
a_ptr, b_ptr, &c_ptr, &profile_result);
if (profile_result.is_valid()) {
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
}
}
auto a_ptr = AsDeviceMemory(a.template flat<T>().data());
auto b_ptr = AsDeviceMemory(b.template flat<T>().data());
auto c_ptr = AsDeviceMemory(out->template flat<T>().data());
// Cublas does
// C = A x B
// where A, B and C are assumed to be in column major.
// We want the output to be in row-major, so we can compute
// C' = B' x A' (' stands for transpose)
if (LaunchBlasGemv<T>::IsSupported() && n == 1) {
// This is a matrix*vector multiply so use GEMV to compute A * b.
// Here we are multiplying in the natural order, so we have to flip
// the transposition flag to compensate for the tensor being stored
// row-major.
LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a, transpose_a ? m : k,
transpose_a ? k : m, a_ptr, b_ptr, &c_ptr);
} else {
bool blas_launch_status =
stream
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, 1.0f,
b_ptr, transpose_b ? k : n, a_ptr,
transpose_a ? m : k, 0.0f, &c_ptr, n)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(errors::Internal(
"Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
"), m=", m, ", n=", n, ", k=", k));
}
// We make sure that each matmul parameter set only gets one pass of
// autotune. If the best result is found, assign it to algorithm_type
// and insert it to autotune map. If all internal kernels of
// cublasGemmEx() returns invalid results, we add kNoAlgorithm to the
// autotune map.
if (best_result.is_valid()) {
algorithm_config.set_algorithm(best_result.algorithm());
}
AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
algorithm_config);
if (algorithm_config.algorithm() != kNoAlgorithm &&
algorithm_config.algorithm() != kDefaultBlasGemm &&
algorithm_config.algorithm() != kDefaultBlasGemv) {
bool cublas_launch_status =
stream
->ThenBlasGemmWithAlgorithm(
blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
&c_ptr, n, computation_type, algorithm_config.algorithm(),
nullptr)
.ok();
if (!cublas_launch_status) {
ctx->SetStatus(errors::Internal(
"Blas GEMM with algorithm launch failed : a.shape=(",
a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
}
}
}
// For the following case, we use normal BlasGemm():
// 1) We didn't set the use_autotune flag;
// 2) compute type does not support autotune;
// 3) no algorithm is found;
// 4) all internal kernels in autotune return invalid results.
if (!use_autotune || !compute_type_supported || algorithms->empty() ||
algorithm_config.algorithm() == kNoAlgorithm ||
algorithm_config.algorithm() == kDefaultBlasGemm ||
algorithm_config.algorithm() == kDefaultBlasGemv) {
if (algorithm_config.algorithm() == kDefaultBlasGemv) {
// This is a matrix*vector multiply so use GEMV to compute A * b.
// Here we are multiplying in the natural order, so we have to flip
// the transposition flag to compensate for the tensor being stored
// row-major.
// TODO(yangzihao): Add Gemv as an autotuning option too.
LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
transpose_a ? m : k, transpose_a ? k : m,
a_ptr, b_ptr, &c_ptr, nullptr);
} else {
// Use C' = B' x A' (' stands for transpose)
bool blas_launch_status =
stream
->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
1.0f, b_ptr, transpose_b ? k : n, a_ptr,
transpose_a ? m : k, 0.0f, &c_ptr, n)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(errors::Internal(
"Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
"), m=", m, ", n=", n, ", k=", k));
}
}
}
}
static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
std::vector<int64>* algorithms,
bool* algorithm_set_flag) {
if (*algorithm_set_flag == false) {
auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
stream->parent()->GetBlasGemmAlgorithms(algorithms);
*algorithm_set_flag = true;
}
}
};
@ -432,14 +257,9 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
public:
explicit MatMulOp(OpKernelConstruction* ctx)
: OpKernel(ctx), algorithms_set_already_(false) {
explicit MatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
ctx, &algorithms_, &algorithms_set_already_);
use_autotune_ = MatmulAutotuneEnable();
}
void Compute(OpKernelContext* ctx) override {
@ -482,14 +302,10 @@ class MatMulOp : public OpKernel {
return;
}
LaunchMatMul<Device, T, USE_CUBLAS>::launch(
ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
LaunchMatMul<Device, T, USE_CUBLAS>::launch(ctx, this, a, b, dim_pair, out);
}
private:
std::vector<int64> algorithms_;
bool algorithms_set_already_;
bool use_autotune_;
bool transpose_a_;
bool transpose_b_;
};

View File

@ -17,9 +17,7 @@ limitations under the License.
#define TENSORFLOW_KERNELS_MATMUL_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
namespace functor {
@ -52,68 +50,6 @@ struct MatMulFunctor {
};
} // end namespace functor
#if GOOGLE_CUDA
// Encapsulate all the shape information that is used in matmul operations.
class MatmulParameters {
public:
MatmulParameters(bool transa, bool transb, uint64 m, uint64 n, uint64 k,
DataType dtype, int device_id)
: transa_(transa),
transb_(transb),
m_(m),
n_(n),
k_(k),
dtype_(dtype),
device_id_(device_id) {
hash_code_ = transa;
hash_code_ = Hash64Combine(hash_code_, transb);
hash_code_ = Hash64Combine(hash_code_, m);
hash_code_ = Hash64Combine(hash_code_, n);
hash_code_ = Hash64Combine(hash_code_, k);
hash_code_ = Hash64Combine(hash_code_, dtype);
hash_code_ = Hash64Combine(hash_code_, device_id);
}
bool operator==(const MatmulParameters& other) const {
return this->get_data_as_tuple() == other.get_data_as_tuple();
}
bool operator!=(const MatmulParameters& other) const {
return !(*this == other);
}
uint64 hash() const { return hash_code_; }
string ToString() const {
// clang-format off
return strings::StrCat(
transa_, ", ", transb_, ", ",
m_, ", ", n_, ", ", k_,
dtype_, ", ", device_id_);
// clang-format on
}
private:
typedef std::tuple<bool, bool, int64, int64, int64, DataType, int>
ParameterDataType;
ParameterDataType get_data_as_tuple() const {
return std::make_tuple(transa_, transb_, m_, n_, k_, dtype_, device_id_);
}
bool transa_;
bool transb_;
uint64 m_;
uint64 n_;
uint64 k_;
DataType dtype_;
int device_id_;
uint64 hash_code_;
};
typedef Eigen::GpuDevice GPUDevice;
#endif // GOOGLE_CUDA
} // end namespace tensorflow
#endif // TENSORFLOW_KERNELS_MATMUL_OP_H_

View File

@ -1,51 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/matmul_autotune.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
bool MatmulAutotuneEnable() {
bool value;
Status status =
ReadBoolFromEnvVar("TF_MATMUL_AUTOTUNE_ENABLE", false, &value);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
return value;
}
bool MatmulDoFP32ComputationFP16Input() {
bool value;
// Feedback from NVIDIA: the "true floating point 16" compute capability is
// absent from compute capability SM 5.2. The native 16 bit floating point
// computation was introduced in SM 5.3 and higher compute capability. So
// for compatibility, set this to be true by default for now.
// TODO(yangzihao): In the future, we need to return three possibilities:
// user-set-true, user-set-false, user-no-setting. In the calling sites,
// check the compatibilities. Note that user-set-false with compute
// capability <= 5.2 will cause an error in the later cublasGemmEx() call.
Status status =
ReadBoolFromEnvVar("TF_FP16_MATMUL_USE_FP32_COMPUTE", true, &value);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
return value;
}
} // namespace tensorflow

View File

@ -1,28 +0,0 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// The utility to check matmul autotune related flags.
#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_
namespace tensorflow {
bool MatmulAutotuneEnable();
bool MatmulDoFP32ComputationFP16Input();
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_MATMUL_AUTOTUNE_H_

View File

@ -3850,48 +3850,6 @@ cuda_py_test(
main = "ops/transpose_benchmark.py",
)
cuda_py_test(
name = "matmul_benchmark",
size = "medium",
srcs = ["ops/matmul_benchmark.py"],
additional_deps = [
":math_ops",
":random_ops",
":client",
":client_testlib",
":control_flow_ops",
":framework_for_generated_wrappers",
":framework_test_lib",
":platform",
":platform_benchmark",
":variables",
"//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
],
main = "ops/matmul_benchmark.py",
)
cuda_py_test(
name = "matmul_benchmark_test",
size = "medium",
srcs = ["ops/matmul_benchmark_test.py"],
additional_deps = [
":math_ops",
":random_ops",
":client",
":client_testlib",
":control_flow_ops",
":framework_for_generated_wrappers",
":platform",
":platform_benchmark",
":matmul_benchmark",
":variables",
"//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
],
main = "ops/matmul_benchmark_test.py",
)
cuda_py_test(
name = "session_benchmark",
srcs = ["client/session_benchmark.py"],

View File

@ -31,9 +31,6 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test as test_lib
# TODO(yangzihao): Currently matmul autotuning is disabled by default. Use
# os.environ["TF_MATMUL_AUTOTUNE_ENABLE"] = "1" to enable it.
def _AddTest(test, op_name, testcase_name, fn):
test_name = "_".join(["test", op_name, testcase_name])

View File

@ -1,143 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark for Matmul operator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import time
import numpy as np
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
def build_graph(device, n, m, k, transpose_a, transpose_b, dtype):
"""Build a graph containing a sequence of matmul operations.
Args:
device: String, the device to run on.
n: tensor A's first dimension size.
m: tensor A's second dimension size.
k: tensor B's second dimension size.
transpose_a: boolean value to show if tensor A is transposed.
transpose_b: boolean value to show if tensor B is transposed.
dtype: numpy data type of the input tensor.
Returns:
A matmul operation to run()
"""
with ops.device('/%s:0' % device):
if not transpose_a:
x = variables.Variable(random_ops.random_uniform([n, m], dtype=dtype))
else:
x = variables.Variable(random_ops.random_uniform([m, n], dtype=dtype))
if not transpose_b:
y = variables.Variable(random_ops.random_uniform([m, k], dtype=dtype))
else:
y = variables.Variable(random_ops.random_uniform([k, m], dtype=dtype))
z = math_ops.matmul(x, y, transpose_a=transpose_a, transpose_b=transpose_b)
return control_flow_ops.group(z)
class MatmulBenchmark(test.Benchmark):
"""Benchmark matmul!"""
def run_graph(self, device, n, m, k, transpose_a, transpose_b, num_iters,
dtype):
"""Run the graph and print its execution time.
Args:
device: String, the device to run on.
n: tensor A's first dimension size.
m: tensor A's second dimension size.
k: tensor B's second dimension size.
transpose_a: boolean value to show if tensor A is transposed.
transpose_b: boolean value to show if tensor B is transposed.
num_iters: number of iterations to run the benchmark.
dtype: numpy data type of the input tensor.
Returns:
The duration of the run in seconds.
"""
graph = ops.Graph()
with graph.as_default():
output = build_graph(device, n, m, k, transpose_a, transpose_b, dtype)
with session_lib.Session(graph=graph) as session:
variables.global_variables_initializer().run()
for _ in range(500):
session.run(output)
start_time = time.time()
for _ in range(num_iters):
session.run(output)
duration = (time.time() - start_time)
num_items = n * m * k * 2
throughput = num_items * num_iters / duration / 1e9
print('%s %s input_info:%s %d %.4fsec, %.4fGitems/s.' %
(device, str(dtype), str(n) + 'x' + str(m) + 'x' + str(k) + ',ta:'
+ str(transpose_a) + '.tb:' + str(transpose_b), num_iters,
duration, throughput))
name_template = ('matmul_{device}_{dtype}_input_info_{inputinfo}')
self.report_benchmark(
name=name_template.format(
device=device,
dtype=str(dtype).replace(' ', ''),
inputinfo=str(n) + 'x' + str(m) + 'x' + str(k) + ',ta:' +
str(transpose_a) + '.tb:' + str(transpose_b)).replace(' ', ''),
iters=num_iters,
wall_time=duration)
return duration
def run_test_gpu(self, n, m, k, transpose_a, transpose_b, dtype, num_iters):
self.run_graph('gpu', n, m, k, transpose_a, transpose_b, num_iters, dtype)
def test_round(self, num_iters):
dtypes = [np.float32, np.float64]
for dtype in dtypes:
for n, m, (transpose_a, transpose_b) in itertools.product(
[512, 1024], [1, 8, 16, 128], [(False, False), (True, False),
(False, True)]):
k = n
self.run_test_gpu(n, m, k, transpose_a, transpose_b, dtype, num_iters)
for n, m, k, (transpose_a, transpose_b) in itertools.product(
[200], [1, 8, 20], [10000], [(False, False), (True, False), (False,
True)]):
self.run_test_gpu(n, m, k, transpose_a, transpose_b, dtype, num_iters)
for (n, m, k), (transpose_a, transpose_b) in itertools.product(
[(200, 20, 20000), (1, 10000, 200)], [(False, False), (True, False),
(False, True)]):
self.run_test_gpu(n, m, k, transpose_a, transpose_b, dtype, num_iters)
def benchmark_matmul(self):
num_iters = 200
for _ in range(10):
self.test_round(num_iters)
if __name__ == '__main__':
test.main()

View File

@ -1,122 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for matmul_benchmark.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import numpy as np
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import ops
from tensorflow.python.ops import matmul_benchmark
from tensorflow.python.platform import test as googletest
from tensorflow.python.platform import tf_logging
def BuildGraphTest(n, m, k, transpose_a, transpose_b, dtype):
def Test(self):
if not googletest.is_gpu_available():
tf_logging.info("Skipping BuildGraphTest %s", (n, m, k, transpose_a,
transpose_b))
return
tf_logging.info("Testing BuildGraphTest %s", (n, m, k, transpose_a,
transpose_b))
self._VerifyBuildGraph(n, m, k, transpose_a, transpose_b, dtype)
return Test
def RunGraphTest(n, m, k, transpose_a, transpose_b, dtype):
def Test(self):
if not googletest.is_gpu_available():
tf_logging.info("Skipping RunGraphTest %s", (n, m, k, transpose_a,
transpose_b))
return
tf_logging.info("Testing RunGraphTest %s", (n, m, k, transpose_a,
transpose_b))
self._VerifyRunGraph(n, m, k, transpose_a, transpose_b, dtype)
return Test
class MatmulBenchmarkTest(googletest.TestCase):
def _StripNode(self, nd):
snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
if nd.device:
snode.device = nd.device
return snode
def _StripGraph(self, gd):
return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node])
def _VerifyBuildGraph(self, n, m, k, transpose_a, transpose_b, dtype):
graph = ops.Graph()
with graph.as_default():
matmul_benchmark.build_graph("gpu", n, m, k, transpose_a, transpose_b,
dtype)
gd = graph.as_graph_def()
self.assertProtoEquals("""
node { name: "random_uniform/shape" op: "Const" device: "/device:GPU:0" }
node { name: "random_uniform/min" op: "Const" device: "/device:GPU:0" }
node { name: "random_uniform/max" op: "Const" device: "/device:GPU:0" }
node { name: "random_uniform/RandomUniform" op: "RandomUniform" input: "random_uniform/shape" device: "/device:GPU:0" }
node { name: "random_uniform/sub" op: "Sub" input: "random_uniform/max" input: "random_uniform/min" device: "/device:GPU:0" }
node { name: "random_uniform/mul" op: "Mul" input: "random_uniform/RandomUniform" input: "random_uniform/sub" device: "/device:GPU:0" }
node { name: "random_uniform" op: "Add" input: "random_uniform/mul" input: "random_uniform/min" device: "/device:GPU:0" }
node { name: "Variable" op: "VariableV2" device: "/device:GPU:0" }
node { name: "Variable/Assign" op: "Assign" input: "Variable" input: "random_uniform" device: "/device:GPU:0" }
node { name: "Variable/read" op: "Identity" input: "Variable" device: "/device:GPU:0" }
node { name: "random_uniform_1/shape" op: "Const" device: "/device:GPU:0" }
node { name: "random_uniform_1/min" op: "Const" device: "/device:GPU:0" }
node { name: "random_uniform_1/max" op: "Const" device: "/device:GPU:0" }
node { name: "random_uniform_1/RandomUniform" op: "RandomUniform" input: "random_uniform_1/shape" device: "/device:GPU:0" }
node { name: "random_uniform_1/sub" op: "Sub" input: "random_uniform_1/max" input: "random_uniform_1/min" device: "/device:GPU:0" }
node { name: "random_uniform_1/mul" op: "Mul" input: "random_uniform_1/RandomUniform" input: "random_uniform_1/sub" device: "/device:GPU:0" }
node { name: "random_uniform_1" op: "Add" input: "random_uniform_1/mul" input: "random_uniform_1/min" device: "/device:GPU:0" }
node { name: "Variable_1" op: "VariableV2" device: "/device:GPU:0" }
node { name: "Variable_1/Assign" op: "Assign" input: "Variable_1" input: "random_uniform_1" device: "/device:GPU:0" }
node { name: "Variable_1/read" op: "Identity" input: "Variable_1" device: "/device:GPU:0" }
node { name: "MatMul" op: "MatMul" input: "Variable/read" input: "Variable_1/read" device: "/device:GPU:0" }
node { name: "group_deps" op: "NoOp" input: "^MatMul" device: "/device:GPU:0" }
""", self._StripGraph(gd))
def _VerifyRunGraph(self, n, m, k, transpose_a, transpose_b, dtype):
benchmark_instance = matmul_benchmark.MatmulBenchmark()
duration = benchmark_instance.run_graph("gpu", n, m, k, transpose_a,
transpose_b, 1, dtype)
self.assertTrue(duration > 1e-6)
if __name__ == "__main__":
dtypes = [np.float32, np.float64]
index = 0
for _dtype in dtypes:
for _n, _m, (_transpose_a, _transpose_b) in itertools.product(
[512, 1024], [1, 8, 16, 128], [(False, False), (True, False), (False,
True)]):
_k = _n
setattr(MatmulBenchmarkTest, "testBuildGraph_" + str(index),
BuildGraphTest(_n, _m, _k, _transpose_a, _transpose_b, _dtype))
setattr(MatmulBenchmarkTest, "testRunGraph_" + str(index),
RunGraphTest(_n, _m, _k, _transpose_a, _transpose_b, _dtype))
index += 1
googletest.main()

View File

@ -67,10 +67,6 @@ string SideString(Side s) {
}
}
// -- AlgorithmConfig
string AlgorithmConfig::ToString() const { return port::StrCat(algorithm_); }
string ComputationTypeString(ComputationType ty) {
switch (ty) {
case ComputationType::kF16:

View File

@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/platform/port.h"
namespace Eigen {
struct half;
@ -107,10 +108,6 @@ string ComputationTypeString(ComputationType ty);
// Opaque identifier for an "algorithm" used by a blas routine. This functions
// as a hint to the blas library.
typedef int64 AlgorithmType;
constexpr AlgorithmType kDefaultAlgorithm = -1;
constexpr AlgorithmType kDefaultBlasGemm = -2;
constexpr AlgorithmType kDefaultBlasGemv = -3;
constexpr AlgorithmType kNoAlgorithm = -4;
// blas uses -1 to represent the default algorithm. This happens to match up
// with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast
@ -137,28 +134,10 @@ class ProfileResult {
private:
bool is_valid_ = false;
AlgorithmType algorithm_ = kDefaultAlgorithm;
AlgorithmType algorithm_ = 0;
float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
};
class AlgorithmConfig {
public:
AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {}
explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {}
AlgorithmType algorithm() const { return algorithm_; }
void set_algorithm(AlgorithmType val) { algorithm_ = val; }
bool operator==(const AlgorithmConfig &other) const {
return this->algorithm_ == other.algorithm_;
}
bool operator!=(const AlgorithmConfig &other) const {
return !(*this == other);
}
string ToString() const;
private:
AlgorithmType algorithm_;
};
// BLAS support interface -- this can be derived from a GPU executor when the
// underlying platform has an BLAS library implementation available. See
// StreamExecutor::AsBlas().
@ -474,29 +453,6 @@ class BlasSupport {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *y, int incy) = 0;
virtual bool DoBlasGemvWithProfiling(
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
int incx, float beta, DeviceMemory<float> *y, int incy,
ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemvWithProfiling(
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
int incx, double beta, DeviceMemory<double> *y, int incy,
ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemvWithProfiling(
Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
int lda, const DeviceMemory<std::complex<float>> &x, int incx,
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemvWithProfiling(
Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
int lda, const DeviceMemory<std::complex<double>> &x, int incx,
std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
int incy, ProfileResult *output_profile_result) = 0;
// Performs a rank-1 update of a general matrix.
//
// a <- alpha * x * y' + a,
@ -979,39 +935,8 @@ class BlasSupport {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc) = 0;
virtual bool DoBlasGemmWithProfiling(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
DeviceMemory<Eigen::half> *c, int ldc,
ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
int ldc, ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
const DeviceMemory<double> &b, int ldb, double beta,
DeviceMemory<double> *c, int ldc,
ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
const DeviceMemory<std::complex<float>> &b, int ldb,
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, std::complex<double> alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
const DeviceMemory<std::complex<double>> &b, int ldb,
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
ProfileResult *output_profile_result) = 0;
// Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
// Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. Note that
// any or all of these algorithms may still be
virtual bool GetBlasGemmAlgorithms(
std::vector<AlgorithmType> *out_algorithms) = 0;
@ -1548,28 +1473,6 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &x, int incx, \
std::complex<double> beta, \
DeviceMemory<std::complex<double>> *y, int incy) override; \
bool DoBlasGemvWithProfiling( \
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, \
const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, \
int incx, float beta, DeviceMemory<float> *y, int incy, \
blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemvWithProfiling( \
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, \
int incx, double beta, DeviceMemory<double> *y, int incy, \
blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemvWithProfiling( \
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, \
int lda, const DeviceMemory<std::complex<float>> &x, int incx, \
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, \
int incy, blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemvWithProfiling( \
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \
std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \
int lda, const DeviceMemory<std::complex<double>> &x, int incx, \
std::complex<double> beta, DeviceMemory<std::complex<double>> *y, \
int incy, blas::ProfileResult *output_profile_result) override; \
bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \
const DeviceMemory<float> &x, int incx, \
const DeviceMemory<float> &y, int incy, \
@ -1848,39 +1751,6 @@ class BlasSupport {
const DeviceMemory<std::complex<double>> &b, int ldb, \
std::complex<double> beta, \
DeviceMemory<std::complex<double>> *c, int ldc) override; \
bool DoBlasGemmWithProfiling( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, float alpha, \
const DeviceMemory<Eigen::half> &a, int lda, \
const DeviceMemory<Eigen::half> &b, int ldb, float beta, \
DeviceMemory<Eigen::half> *c, int ldc, \
blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemmWithProfiling( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
int lda, const DeviceMemory<float> &b, int ldb, float beta, \
DeviceMemory<float> *c, int ldc, \
blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemmWithProfiling( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, double alpha, \
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \
int ldb, double beta, DeviceMemory<double> *c, int ldc, \
blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemmWithProfiling( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
const DeviceMemory<std::complex<float>> &a, int lda, \
const DeviceMemory<std::complex<float>> &b, int ldb, \
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemmWithProfiling( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
const DeviceMemory<std::complex<double>> &a, int lda, \
const DeviceMemory<std::complex<double>> &b, int ldb, \
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \
int ldc, blas::ProfileResult *output_profile_result) override; \
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \
override; \
bool DoBlasGemmWithAlgorithm( \

View File

@ -1857,180 +1857,6 @@ bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
CUDAComplex(CUDAMemoryMutable(c)), ldc);
}
bool CUDABlas::DoBlasGemvWithProfiling(
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
int incx, float beta, DeviceMemory<float> *y, int incy,
blas::ProfileResult *output_profile_result) {
return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
incx, beta, y, incy,
output_profile_result);
}
bool CUDABlas::DoBlasGemvWithProfiling(
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
int incx, double beta, DeviceMemory<double> *y, int incy,
blas::ProfileResult *output_profile_result) {
return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
incx, beta, y, incy,
output_profile_result);
}
bool CUDABlas::DoBlasGemvWithProfiling(
Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
int lda, const DeviceMemory<std::complex<float>> &x, int incx,
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
blas::ProfileResult *output_profile_result) {
return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
incx, beta, y, incy,
output_profile_result);
}
bool CUDABlas::DoBlasGemvWithProfiling(
Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
int lda, const DeviceMemory<std::complex<double>> &x, int incx,
std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
blas::ProfileResult *output_profile_result) {
return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
incx, beta, y, incy,
output_profile_result);
}
bool CUDABlas::DoBlasGemmWithProfiling(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
DeviceMemory<Eigen::half> *c, int ldc,
blas::ProfileResult *output_profile_result) {
return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
lda, b, ldb, beta, c, ldc,
output_profile_result);
}
bool CUDABlas::DoBlasGemmWithProfiling(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
int ldc, blas::ProfileResult *output_profile_result) {
return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
lda, b, ldb, beta, c, ldc,
output_profile_result);
}
bool CUDABlas::DoBlasGemmWithProfiling(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
const DeviceMemory<double> &b, int ldb, double beta,
DeviceMemory<double> *c, int ldc,
blas::ProfileResult *output_profile_result) {
return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
lda, b, ldb, beta, c, ldc,
output_profile_result);
}
bool CUDABlas::DoBlasGemmWithProfiling(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
const DeviceMemory<std::complex<float>> &b, int ldb,
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
blas::ProfileResult *output_profile_result) {
return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
lda, b, ldb, beta, c, ldc,
output_profile_result);
}
bool CUDABlas::DoBlasGemmWithProfiling(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, std::complex<double> alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
const DeviceMemory<std::complex<double>> &b, int ldb,
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
blas::ProfileResult *output_profile_result) {
return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
lda, b, ldb, beta, c, ldc,
output_profile_result);
}
template <typename T>
bool CUDABlas::DoBlasGemvWithProfilingImpl(
Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
const T &beta, DeviceMemory<T> *y, int incy,
blas::ProfileResult *output_profile_result) {
struct TimerDeleter {
void operator()(CUDATimer *t) {
t->Destroy();
delete t;
}
};
std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (output_profile_result != nullptr) {
timer.reset(new CUDATimer(parent_));
if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
return false;
}
}
// Call blasGemm
bool result =
DoBlasGemv(stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
if (timer != nullptr && result) {
// CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
// state.
if (!timer->Stop(AsCUDAStream(stream))) {
return false;
}
output_profile_result->set_is_valid(true);
output_profile_result->set_algorithm(blas::kDefaultBlasGemv);
output_profile_result->set_elapsed_time_in_ms(
timer->GetElapsedMilliseconds());
}
return result;
}
template <typename T, typename ParamType>
bool CUDABlas::DoBlasGemmWithProfilingImpl(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
struct TimerDeleter {
void operator()(CUDATimer *t) {
t->Destroy();
delete t;
}
};
std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (output_profile_result != nullptr) {
timer.reset(new CUDATimer(parent_));
if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
return false;
}
}
// Call blasGemm
bool result = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc);
if (timer != nullptr && result) {
// CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
// state.
if (!timer->Stop(AsCUDAStream(stream))) {
return false;
}
output_profile_result->set_is_valid(true);
output_profile_result->set_algorithm(blas::kDefaultBlasGemm);
output_profile_result->set_elapsed_time_in_ms(
timer->GetElapsedMilliseconds());
}
return result;
}
template <typename InT, typename OutT, typename CompT>
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
@ -2094,9 +1920,6 @@ bool CUDABlas::GetBlasGemmAlgorithms(
std::vector<blas::AlgorithmType> *out_algorithms) {
// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
// were first introduced in CUDA 8.
// Note that when CUDA version and compute capability is not sufficient, we
// still return the out_algorithms. Caller needs to make sure that in this case,
// the returned vector is empty.
#if CUDA_VERSION >= 8000
for (cublasGemmAlgo_t algo :
{CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
@ -2104,10 +1927,8 @@ bool CUDABlas::GetBlasGemmAlgorithms(
CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7}) {
out_algorithms->push_back(algo);
}
return true;
#else
return false;
#endif
return true;
}
bool CUDABlas::DoBlasGemmWithAlgorithm(

View File

@ -127,23 +127,6 @@ class CUDABlas : public blas::BlasSupport {
blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result);
// Helper function for implementing DoBlasGemmWithProfiling.
template <typename T, typename ParamType>
bool DoBlasGemmWithProfilingImpl(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);
// Helper function for implementing DoBlasGemvWithProfiling.
template <typename T>
bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
uint64 m, uint64 n, const T &alpha,
const DeviceMemory<T> &a, int lda,
const DeviceMemory<T> &x, int incx,
const T &beta, DeviceMemory<T> *y, int incy,
blas::ProfileResult *output_profile_result);
// mutex that guards the cuBLAS handle for this device.
mutex mu_;

View File

@ -3458,184 +3458,6 @@ struct ThenBlasWithProfileImpl {
};
} // anonymous namespace
Stream &Stream::ThenBlasGemvWithProfiling(
blas::Transpose trans, uint64 m, uint64 n, float alpha,
const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
int incx, float beta, DeviceMemory<float> *y, int incy,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
PARAM(incy));
ThenBlasWithProfileImpl<
blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
}
Stream &Stream::ThenBlasGemvWithProfiling(
blas::Transpose trans, uint64 m, uint64 n, double alpha,
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
int incx, double beta, DeviceMemory<double> *y, int incy,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
PARAM(incy));
ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
const DeviceMemory<double> &, int,
const DeviceMemory<double> &, int, double,
DeviceMemory<double> *, int>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
}
Stream &Stream::ThenBlasGemvWithProfiling(
blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
const DeviceMemory<std::complex<float>> &x, int incx,
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
PARAM(incy));
ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
const DeviceMemory<std::complex<float>> &, int,
const DeviceMemory<std::complex<float>> &, int,
std::complex<float>,
DeviceMemory<std::complex<float>> *, int>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
}
Stream &Stream::ThenBlasGemvWithProfiling(
blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
const DeviceMemory<std::complex<double>> &x, int incx,
std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
PARAM(incy));
ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
const DeviceMemory<std::complex<double>> &, int,
const DeviceMemory<std::complex<double>> &, int,
std::complex<double>,
DeviceMemory<std::complex<double>> *, int>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
}
Stream &Stream::ThenBlasGemmWithProfiling(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
const DeviceMemory<Eigen::half> &b, int ldb, float beta,
DeviceMemory<Eigen::half> *c, int ldc,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc));
ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
uint64, float, const DeviceMemory<Eigen::half> &, int,
const DeviceMemory<Eigen::half> &, int, float,
DeviceMemory<Eigen::half> *, int>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
output_profile_result);
}
Stream &Stream::ThenBlasGemmWithProfiling(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
int ldc, blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc));
ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
uint64, float, const DeviceMemory<float> &, int,
const DeviceMemory<float> &, int, float,
DeviceMemory<float> *, int>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
output_profile_result);
}
Stream &Stream::ThenBlasGemmWithProfiling(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
const DeviceMemory<double> &b, int ldb, double beta,
DeviceMemory<double> *c, int ldc,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc));
ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
uint64, double, const DeviceMemory<double> &, int,
const DeviceMemory<double> &, int, double,
DeviceMemory<double> *, int>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
output_profile_result);
}
Stream &Stream::ThenBlasGemmWithProfiling(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
const DeviceMemory<std::complex<float>> &b, int ldb,
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc));
ThenBlasWithProfileImpl<
blas::Transpose, blas::Transpose, uint64, uint64, uint64,
std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
DeviceMemory<std::complex<float>> *, int>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
output_profile_result);
}
Stream &Stream::ThenBlasGemmWithProfiling(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, std::complex<double> alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
const DeviceMemory<std::complex<double>> &b, int ldb,
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc));
ThenBlasWithProfileImpl<
blas::Transpose, blas::Transpose, uint64, uint64, uint64,
std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
DeviceMemory<std::complex<double>> *, int>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
output_profile_result);
}
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,

View File

@ -934,31 +934,6 @@ class Stream {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *y, int incy);
Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
float alpha, const DeviceMemory<float> &a,
int lda, const DeviceMemory<float> &x,
int incx, float beta,
DeviceMemory<float> *y, int incy,
blas::ProfileResult *output_profile_result);
Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
double alpha, const DeviceMemory<double> &a,
int lda, const DeviceMemory<double> &x,
int incx, double beta,
DeviceMemory<double> *y, int incy,
blas::ProfileResult *output_profile_result);
Stream &ThenBlasGemvWithProfiling(
blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
const DeviceMemory<std::complex<float>> &x, int incx,
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
blas::ProfileResult *output_profile_result);
Stream &ThenBlasGemvWithProfiling(
blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
const DeviceMemory<std::complex<double>> &x, int incx,
std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
int incy, blas::ProfileResult *output_profile_result);
// See BlasSupport::DoBlasGer.
Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
const DeviceMemory<float> &x, int incx,
@ -1274,44 +1249,6 @@ class Stream {
std::complex<double> beta,
DeviceMemory<std::complex<double>> *c, int ldc);
Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n,
uint64 k, float alpha,
const DeviceMemory<Eigen::half> &a, int lda,
const DeviceMemory<Eigen::half> &b, int ldb,
float beta, DeviceMemory<Eigen::half> *c,
int ldc,
blas::ProfileResult *output_profile_result);
Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n,
uint64 k, float alpha,
const DeviceMemory<float> &a, int lda,
const DeviceMemory<float> &b, int ldb,
float beta, DeviceMemory<float> *c, int ldc,
blas::ProfileResult *output_profile_result);
Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n,
uint64 k, double alpha,
const DeviceMemory<double> &a, int lda,
const DeviceMemory<double> &b, int ldb,
double beta, DeviceMemory<double> *c,
int ldc,
blas::ProfileResult *output_profile_result);
Stream &ThenBlasGemmWithProfiling(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, std::complex<float> alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
const DeviceMemory<std::complex<float>> &b, int ldb,
std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
blas::ProfileResult *output_profile_result);
Stream &ThenBlasGemmWithProfiling(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
uint64 k, std::complex<double> alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
const DeviceMemory<std::complex<double>> &b, int ldb,
std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
blas::ProfileResult *output_profile_result);
// See BlasSupport::DoBlasGemmWithAlgorithm.
Stream &ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,